JenspederM/kedro-databricks

src/kedro_databricks/init.py

from __future__ import annotations

import json
import logging
import re
import shutil
import tempfile
from pathlib import Path

import tomlkit
from kedro.framework.startup import ProjectMetadata

from kedro_databricks.utils import run_cmd

DEFAULT_NODE_TYPE_ID = {
    "aws": "m5.xlarge",
    "azure": "Standard_DS3_v2",
    "gcp": "n1-standard-4",
}

_bundle_config_template = """
# This is a Databricks Asset Bundle (DAB) definition.
# See https://docs.databricks.com/dev-tools/bundles/index.html for documentation.
bundle:
  name: {{ .project_slug }}

artifacts:
  default:
    type: whl
    build: kedro package
    path: .

include:
  - resources/*.yml
  - resources/**/*.yml
  - resources/*.yaml
  - resources/**/*.yaml

targets:
  # The 'local' target, used for development purposes.
  # Whenever a developer deploys using 'local', they get their own copy.
  local:
    # We use 'mode: development' to make sure everything deployed to this target gets a prefix
    # like '[dev my_user_name]'. Setting this mode also disables any schedules and
    # automatic triggers for jobs and enables the 'development' mode for Delta Live Tables pipelines.
    mode: development
    default: true
    workspace:
      host: {{workspace_host}}

  # The 'prod' target, used for production deployment.
  prod:
    # For production deployments, we only have a single copy, so we override the
    # workspace.root_path default of
    # /Users/${workspace.current_user.userName}/.bundle/${bundle.target}/${bundle.name}
    # to a path that is not specific to the current user.
    #
    # By making use of 'mode: production' we enable strict checks
    # to make sure we have correctly configured this target.
    mode: production
    workspace:
      host: {{workspace_host}}
      root_path: /Shared/.bundle/prod/${bundle.name}
    {{- if not is_service_principal}}
    run_as:
      # This runs as {{user_name}} in production. Alternatively,
      # a service principal could be used here using service_principal_name
      # (see Databricks documentation).
      user_name: {{user_name}}
    {{end -}}
"""

_bundle_init_template = {
    "welcome_message": "Creating a Databricks asset bundle definition...",
    "min_databricks_cli_version": "v0.212.2",
    "properties": {
        "project_name": {
            "order": 1,
            "type": "string",
            "default": "kedro project",
            "pattern": "^[^.\\\\/A-Z]{3,}$",
            "pattern_match_failure_message": 'Project name must be at least 3 characters long and cannot contain the following characters: "\\", "/", " ", ".", and must be all lowercase letters.',
            "description": "\nProject Name. Default",
        },
        "project_slug": {
            "order": 2,
            "type": "string",
            "default": "{{ ((regexp `[- ]`).ReplaceAllString .project_name `_`) -}}",
            "description": "\nProject slug. Default",
            "hidden": True,
        },
    },
    "success_message": "\n*** Asset Bundle successfully created for '{{.project_name}}'! ***",
}
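
# `_bundle_init_template` is written to `databricks_template_schema.json` and tells
# `databricks bundle init` which properties the template expects; their values are
# supplied non-interactively through the JSON file passed via `--config-file`.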

_bundle_override_template = """
# Files named `databricks*` or `databricks/**` will be used to apply overrides to the
# generated asset bundle resources. The overrides should be specified according to the
# Databricks REST API's `Create a new job` endpoint. To learn more, visit their
# documentation at https://docs.databricks.com/api/workspace/jobs/create

{default_key}:
    job_clusters:
        - job_cluster_key: {default_key}
          new_cluster:
              spark_version: 14.3.x-scala2.12
              node_type_id: {node_type_id}
              num_workers: 1
              spark_env_vars:
                  KEDRO_LOGGING_CONFIG: "/dbfs/FileStore/{package_name}/conf/logging.yml"
    tasks:
        - task_key: {default_key}
          job_cluster_key: {default_key}
"""

_databricks_run_template = """
# This file is used to run Kedro pipelines on Databricks.
# It is automatically generated by the Kedro-Databricks plugin.
# Do not modify this file directly.

import argparse
import logging

from kedro.framework.project import configure_project
from kedro.framework.session import KedroSession


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", dest="env", type=str)
    parser.add_argument("--conf-source", dest="conf_source", type=str)
    parser.add_argument("--package-name", dest="package_name", type=str)
    parser.add_argument("--nodes", dest="nodes", type=str)

    args = parser.parse_args()
    env = args.env
    conf_source = args.conf_source
    package_name = args.package_name
    nodes = [node.strip() for node in args.nodes.split(",") if node.strip()]

    # https://kb.databricks.com/notebooks/cmd-c-on-object-id-p0.html
    logging.getLogger("py4j.java_gateway").setLevel(logging.ERROR)
    logging.getLogger("py4j.py4j.clientserver").setLevel(logging.ERROR)

    configure_project(package_name)
    with KedroSession.create(env=env, conf_source=conf_source) as session:
        if len(nodes) > 0:
            session.run(node_names=nodes)
        else:
            session.run()

if __name__ == "__main__":
    main()
"""


def write_bundle_template(metadata: ProjectMetadata):
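    """Initialize the Databricks asset bundle configuration for the project.

    Renders `_bundle_config_template` through `databricks bundle init`, which also
    taps into the Databricks CLI's authentication, and writes the resulting
    `databricks.yml` to the project root.

    Args:
        metadata: Kedro project metadata for the current project.
    """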
    MSG = "Creating databricks configuration"
    package_name = metadata.package_name
    project_path = metadata.project_path
    log = logging.getLogger(package_name)
    if shutil.which("databricks") is None:  # pragma: no cover
        raise Exception("databricks CLI is not installed")

    config_path = project_path / "databricks.yml"
    if config_path.exists():
        log.warning(f"{MSG}: {config_path.relative_to(project_path)} already exists.")
        return

    config = {
        "project_name": package_name,
        "project_slug": package_name,
    }

    # Stage the template schema and the bundle template in a temporary assets
    # directory that `databricks bundle init` can consume.
    assets_dir = Path(tempfile.mkdtemp())
    with open(assets_dir / "databricks_template_schema.json", "w") as f:
        f.write(json.dumps(_bundle_init_template))

    template_dir = assets_dir / "template"
    template_dir.mkdir(exist_ok=True)
    with open(template_dir / "databricks.yml.tmpl", "w") as f:
        f.write(_bundle_config_template)

    # The template parameters are passed to the CLI through a temporary JSON file.
    template_params = tempfile.NamedTemporaryFile(delete=False)
    template_params.write(json.dumps(config).encode())
    template_params.close()

    # We utilize the databricks CLI to create the bundle configuration.
    # This is a bit hacky, but it allows the plugin to tap into the authentication
    # mechanism of the databricks CLI and thereby avoid the need to store credentials
    # in the plugin.

    run_cmd(
        [
            "databricks",
            "bundle",
            "init",
            assets_dir.as_posix(),
            "--config-file",
            template_params.name,
            "--output-dir",
            project_path.as_posix(),
        ],
        msg=MSG,
    )
    log.info(f"{MSG}: Wrote {config_path.relative_to(project_path)}")
    shutil.rmtree(assets_dir)
    Path(template_params.name).unlink(missing_ok=True)


def write_override_template(metadata: ProjectMetadata, default_key: str, provider: str):
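    """Write the default bundle override configuration.

    Renders `_bundle_override_template` with the given default key and a
    provider-specific node type, and writes it to `conf/base/databricks.yml`
    unless that file already exists.

    Args:
        metadata: Kedro project metadata for the current project.
        default_key: Key used for the default job cluster and task overrides.
        provider: Cloud provider name (`aws`, `azure` or `gcp`).
    """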
    MSG = "Creating bundle override configuration"
    package_name = metadata.package_name
    project_path = metadata.project_path
    log = logging.getLogger(package_name)
    override_path = Path(project_path) / "conf" / "base" / "databricks.yml"
    node_type_id = DEFAULT_NODE_TYPE_ID[provider]

    if override_path.exists():
        log.warning(f"{MSG}: {override_path.relative_to(project_path)} already exists.")
        return

    with open(override_path, "w") as f:
        f.write(
            _bundle_override_template.format(
                default_key=default_key,
                package_name=package_name,
                node_type_id=node_type_id,
            )
        )
    log.info(f"{MSG}: Wrote {override_path.relative_to(project_path)}")


def write_databricks_run_script(metadata: ProjectMetadata):
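    """Create the entry point used to run Kedro pipelines on Databricks.

    Writes `databricks_run.py` into the project's source package and registers
    a `databricks_run` console script in `pyproject.toml` so that packaged
    wheels expose it as an executable entry point.

    Args:
        metadata: Kedro project metadata for the current project.
    """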
    MSG = "Creating Databricks run script"
    package_name = metadata.package_name
    project_path = metadata.project_path
    log = logging.getLogger(package_name)
    script_path = project_path / "src" / package_name / "databricks_run.py"
    toml_path = project_path / "pyproject.toml"

    with open(script_path, "w") as f:
        f.write(_databricks_run_template)
    log.info(f"{MSG}: Wrote {script_path.relative_to(project_path)}")

    with open(toml_path) as f:
        toml = tomlkit.load(f)

    scripts = toml.get("project", {}).get("scripts", {})
    if "databricks_run" not in scripts:
        scripts["databricks_run"] = f"{package_name}.databricks_run:main"
        toml["project"]["scripts"] = scripts

    log.info(f"{MSG}: Added script to {toml_path.relative_to(project_path)}")
    with open(toml_path, "w") as f:
        tomlkit.dump(toml, f)


def substitute_catalog_paths(metadata: ProjectMetadata):
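    """Point DBFS catalog paths at the project's own package name.

    Scans `catalog.yml` in every environment under `conf/` and rewrites
    `/dbfs/FileStore/<name>/data...` paths so that `<name>` matches the
    project's package name.

    Args:
        metadata: Kedro project metadata for the current project.
    """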
    MSG = "Substituting DBFS paths"
    package_name = metadata.package_name
    project_path = metadata.project_path
    log = logging.getLogger(package_name)
    conf_dir = project_path / "conf"
    envs = [d for d in conf_dir.iterdir() if d.is_dir()]
    regex = r"(.*/dbfs/FileStore/)(.*)(/data.*)"
    for env in envs:
        path = env / "catalog.yml"
        log.info(f"{MSG}: Checking {path.relative_to(project_path)}")

        if not path.exists():
            log.warning(f"{MSG}: {path.relative_to(project_path)} does not exist.")
            continue

        with open(path) as f:
            content = f.readlines()

        new_content = _parse_content(metadata, regex, path, content)

        with open(path, "w") as f:
            f.writelines(new_content)


def _parse_content(metadata, regex, path, content):
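    """Replace the middle group of `regex` with the package name on each line.

    Returns the list of (possibly rewritten) lines and logs every substitution
    that was made.
    """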
    package_name = metadata.package_name
    project_path = metadata.project_path
    log = logging.getLogger(package_name)
    new_content = []
    for line in content:
        new_line = re.sub(regex, f"\\g<1>{package_name}\\g<3>", line)
        if new_line != line:
            log.info(
                f"{path.relative_to(project_path)}: "
                f"Substituted: {line.strip()} -> {new_line.strip()}"
            )
        new_content.append(new_line)
    return new_content