src/kedro_databricks/init.py
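"""Helpers for initialising a Kedro project as a Databricks asset bundle:
writing the bundle configuration, the resource overrides, and the entry-point
script used to run Kedro pipelines on Databricks."""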
from __future__ import annotations
import json
import logging
import re
import shutil
import tempfile
from pathlib import Path
import tomlkit
from kedro.framework.startup import ProjectMetadata
from kedro_databricks.utils import run_cmd
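# Default Databricks cluster node types per cloud provider, used when writing
# the bundle override template.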
DEFAULT_NODE_TYPE_ID = {
    "aws": "m5.xlarge",
    "azure": "Standard_DS3_v2",
    "gcp": "n1-standard-4",
}
_bundle_config_template = """
# This is a Databricks asset bundle definition for dab.
# See https://docs.databricks.com/dev-tools/bundles/index.html for documentation.
bundle:
  name: {{ .project_slug }}
artifacts:
  default:
    type: whl
    build: kedro package
    path: .
include:
  - resources/*.yml
  - resources/**/*.yml
  - resources/*.yaml
  - resources/**/*.yaml
targets:
  # The 'local' target, used for development purposes.
  # Whenever a developer deploys using 'local', they get their own copy.
  local:
    # We use 'mode: development' to make sure everything deployed to this target gets a prefix
    # like '[dev my_user_name]'. Setting this mode also disables any schedules and
    # automatic triggers for jobs and enables the 'development' mode for Delta Live Tables pipelines.
    mode: development
    default: true
    workspace:
      host: {{workspace_host}}
  # The 'prod' target, used for production deployment.
  prod:
    # For production deployments, we only have a single copy, so we override the
    # workspace.root_path default of
    # /Users/${workspace.current_user.userName}/.bundle/${bundle.target}/${bundle.name}
    # to a path that is not specific to the current user.
    #
    # By making use of 'mode: production' we enable strict checks
    # to make sure we have correctly configured this target.
    mode: production
    workspace:
      host: {{workspace_host}}
      root_path: /Shared/.bundle/prod/${bundle.name}
    {{- if not is_service_principal}}
    run_as:
      # This runs as {{user_name}} in production. Alternatively,
      # a service principal could be used here using service_principal_name
      # (see Databricks documentation).
      user_name: {{user_name}}
    {{end -}}
"""
_bundle_init_template = {
    "welcome_message": "Creating a Databricks asset bundle definition...",
    "min_databricks_cli_version": "v0.212.2",
    "properties": {
        "project_name": {
            "order": 1,
            "type": "string",
            "default": "kedro project",
            "pattern": "^[^.\\\\/A-Z]{3,}$",
            "pattern_match_failure_message": 'Project name must be at least 3 characters long and cannot contain the following characters: "\\", "/", " ", ".", and must be all lowercase letters.',
            "description": "\nProject Name. Default",
        },
        "project_slug": {
            "order": 2,
            "type": "string",
            "default": "{{ ((regexp `[- ]`).ReplaceAllString .project_name `_`) -}}",
            "description": "\nProject slug. Default",
            "hidden": True,
        },
    },
    "success_message": "\n*** Asset Bundle successfully created for '{{.project_name}}'! ***",
}
_bundle_override_template = """
# Files named `databricks*` or `databricks/**` will be used to apply overrides to the
# generated asset bundle resources. The overrides should be specified according to the
# Databricks REST API's `Create a new job` endpoint. To learn more, visit their
# documentation at https://docs.databricks.com/api/workspace/jobs/create
{default_key}:
  job_clusters:
    - job_cluster_key: {default_key}
      new_cluster:
        spark_version: 14.3.x-scala2.12
        node_type_id: {node_type_id}
        num_workers: 1
        spark_env_vars:
          KEDRO_LOGGING_CONFIG: "/dbfs/FileStore/{package_name}/conf/logging.yml"
  tasks:
    - task_key: {default_key}
      job_cluster_key: {default_key}
"""
_databricks_run_template = """
# This file is used to run Kedro pipelines on Databricks.
# It is automatically generated by the Kedro-Databricks plugin.
# Do not modify this file directly.
import argparse
import logging
from kedro.framework.project import configure_project
from kedro.framework.session import KedroSession
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", dest="env", type=str)
    parser.add_argument("--conf-source", dest="conf_source", type=str)
    parser.add_argument("--package-name", dest="package_name", type=str)
    parser.add_argument("--nodes", dest="nodes", type=str, default="")
    args = parser.parse_args()
    env = args.env
    conf_source = args.conf_source
    package_name = args.package_name
    # Only keep non-empty node names so an omitted --nodes runs the full pipeline.
    nodes = [node.strip() for node in args.nodes.split(",") if node.strip()]
    # https://kb.databricks.com/notebooks/cmd-c-on-object-id-p0.html
    logging.getLogger("py4j.java_gateway").setLevel(logging.ERROR)
    logging.getLogger("py4j.py4j.clientserver").setLevel(logging.ERROR)
    configure_project(package_name)
    with KedroSession.create(env=env, conf_source=conf_source) as session:
        if nodes:
            session.run(node_names=nodes)
        else:
            session.run()
if __name__ == "__main__":
    main()
"""
def write_bundle_template(metadata: ProjectMetadata):
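    """Write a `databricks.yml` bundle configuration to the project root.

    Delegates to `databricks bundle init` with a temporary template so the
    plugin reuses the Databricks CLI's own authentication and templating.
    If `databricks.yml` already exists, a warning is logged and nothing is written.

    Args:
        metadata: Kedro project metadata for the current project.
    """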
    MSG = "Creating databricks configuration"
    package_name = metadata.package_name
    project_path = metadata.project_path
    log = logging.getLogger(package_name)
    if shutil.which("databricks") is None:  # pragma: no cover
        raise Exception("databricks CLI is not installed")
    config = {
        "project_name": package_name,
        "project_slug": package_name,
    }
    assets_dir = Path(tempfile.mkdtemp())
    with open(assets_dir / "databricks_template_schema.json", "w") as f:
        f.write(json.dumps(_bundle_init_template))
    template_dir = assets_dir / "template"
    template_dir.mkdir(exist_ok=True)
    with open(f"{template_dir}/databricks.yml.tmpl", "w") as f:
        f.write(_bundle_config_template)
    template_params = tempfile.NamedTemporaryFile(delete=False)
    template_params.write(json.dumps(config).encode())
    template_params.close()
    config_path = project_path / "databricks.yml"
    if config_path.exists():
        log.warning(f"{MSG}: {config_path.relative_to(project_path)} already exists.")
        return
    # We utilize the databricks CLI to create the bundle configuration.
    # This is a bit hacky, but it allows the plugin to tap into the authentication
    # mechanism of the databricks CLI and thereby avoid the need to store credentials
    # in the plugin.
    run_cmd(
        [
            "databricks",
            "bundle",
            "init",
            assets_dir.as_posix(),
            "--config-file",
            template_params.name,
            "--output-dir",
            project_path.as_posix(),
        ],
        msg=MSG,
    )
    log.info(f"{MSG}: Wrote {config_path.relative_to(project_path)}")
    shutil.rmtree(assets_dir)
def write_override_template(metadata: ProjectMetadata, default_key: str, provider: str):
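    """Write the default bundle override file to `conf/base/databricks.yml`.

    The generated overrides follow the payload of the Databricks Jobs
    `Create a new job` endpoint and define a default job cluster and task.
    If the file already exists, a warning is logged and nothing is written.

    Args:
        metadata: Kedro project metadata for the current project.
        default_key: Key used for the default job cluster and task overrides.
        provider: Cloud provider used to pick a default `node_type_id`
            (one of "aws", "azure" or "gcp").
    """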
    MSG = "Creating bundle override configuration"
    package_name = metadata.package_name
    project_path = metadata.project_path
    log = logging.getLogger(package_name)
    override_path = Path(project_path) / "conf" / "base" / "databricks.yml"
    node_type_id = DEFAULT_NODE_TYPE_ID.get(provider)
    if node_type_id is None:
        raise ValueError(
            f"Unknown provider '{provider}'. Expected one of: {', '.join(DEFAULT_NODE_TYPE_ID)}"
        )
    if override_path.exists():
        log.warning(f"{MSG}: {override_path.relative_to(project_path)} already exists.")
        return
    with open(override_path, "w") as f:
        f.write(
            _bundle_override_template.format(
                default_key=default_key,
                package_name=package_name,
                node_type_id=node_type_id,
            )
        )
    log.info(f"{MSG}: Wrote {override_path.relative_to(project_path)}")
def write_databricks_run_script(metadata: ProjectMetadata):
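    """Write `databricks_run.py` into the project's source package.

    Also registers a `databricks_run` console script in `pyproject.toml` so the
    packaged project exposes an entry point that Databricks job tasks can invoke.

    Args:
        metadata: Kedro project metadata for the current project.
    """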
    MSG = "Creating Databricks run script"
    package_name = metadata.package_name
    project_path = metadata.project_path
    log = logging.getLogger(package_name)
    script_path = project_path / "src" / package_name / "databricks_run.py"
    toml_path = project_path / "pyproject.toml"
    with open(script_path, "w") as f:
        f.write(_databricks_run_template)
    log.info(f"{MSG}: Wrote {script_path.relative_to(project_path)}")
    with open(toml_path) as f:
        toml = tomlkit.load(f)
    scripts = toml.get("project", {}).get("scripts", {})
    if "databricks_run" not in scripts:
        scripts["databricks_run"] = f"{package_name}.databricks_run:main"
        toml["project"]["scripts"] = scripts
        log.info(f"{MSG}: Added script to {toml_path.relative_to(project_path)}")
    with open(toml_path, "w") as f:
        tomlkit.dump(toml, f)
def substitute_catalog_paths(metadata: ProjectMetadata):
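    """Point DBFS paths in each environment's `catalog.yml` at this package.

    Rewrites the package segment of `/dbfs/FileStore/<name>/data...` paths in
    every `conf/<env>/catalog.yml` that exists, logging each substitution.

    Args:
        metadata: Kedro project metadata for the current project.
    """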
    MSG = "Substituting DBFS paths"
    package_name = metadata.package_name
    project_path = metadata.project_path
    log = logging.getLogger(package_name)
    conf_dir = project_path / "conf"
    envs = [d for d in conf_dir.iterdir() if d.is_dir()]
    regex = r"(.*/dbfs/FileStore/)(.*)(/data.*)"
    for env in envs:
        path = env / "catalog.yml"
        log.info(f"{MSG}: Checking {path.relative_to(project_path)}")
        if not path.exists():
            log.warning(f"{MSG}: {path.relative_to(project_path)} does not exist.")
            continue
        with open(path) as f:
            content = f.readlines()
        new_content = _parse_content(metadata, regex, path, content)
        with open(path, "w") as f:
            f.writelines(new_content)
def _parse_content(metadata, regex, path, content):
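    """Return `content` with DBFS paths rewritten to use the project's package
    name, logging every line that was substituted."""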
    package_name = metadata.package_name
    project_path = metadata.project_path
    log = logging.getLogger(package_name)
    new_content = []
    for line in content:
        new_line = re.sub(regex, f"\\g<1>{package_name}\\g<3>", line)
        if new_line != line:
            log.info(
                f"{path.relative_to(project_path)}: "
                f"Substituted: {line.strip()} -> {new_line.strip()}"
            )
        new_content.append(new_line)
    return new_content