kedro_airflow_k8s/config.py
import os
from kedro.config import MissingConfigException
DEFAULT_CONFIG_TEMPLATE = """
# Base URL of Apache Airflow; should include the scheme (http/https)
host: {url}
# Directory from which Apache Airflow reads DAG definitions
output: {output}
# Configuration used to run the pipeline
run_config:
# Name of the image used to run the pipeline steps
image: {image}
# Pull policy to be used for the steps. Use Always if you push images to the
# same tag, or Never if you only use local images
image_pull_policy: IfNotPresent
# Pod startup timeout in seconds
startup_timeout: 600
# Namespace for Airflow pods to be created
namespace: airflow
# Name of the Airflow experiment to be created
experiment_name: {project}
# Name of the DAG as it is presented in Airflow
run_name: {run_name}
# Apache Airflow cron expression for scheduled runs
cron_expression: "@daily"
# Optional start date in format YYYYMMDD
#start_date: "20210721"
# Optional pipeline description
#description: "Very Important Pipeline"
# Comma-separated list of image pull secret names
#image_pull_secrets: my-registry-credentials
# Service account name to execute nodes with
#service_account_name: default
# List of handlers executed after task failure
failure_handlers: []
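# Example of a slack failure handler (illustrative values; the connection id
# must reference an existing Airflow connection):
#failure_handlers:
# - type: slack
#   connection_id: slack_conn
#   message_template: "Task failed!"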
# Optional volume specification
volume:
# Storage class - use null (or no value) to use the default storage
# class deployed on the Kubernetes cluster
storageclass: # default
# The size of the volume that is created. Applicable for some storage
# classes
size: 1Gi
# Access mode of the volume used to exchange data. ReadWriteMany is
# preferred, but it is not supported in some environments (like GKE)
# Default value: ReadWriteOnce
#access_modes: [ReadWriteMany]
# Flag indicating if the data-volume-init step (copying raw data to the
# fresh volume) should be skipped
skip_init: False
# Allows specifying the fsGroup that executes pipelines within containers
# Default: root user group (to avoid issues with volumes in GKE)
owner: 0
# Set to true to disable the volume entirely; false by default
disabled: False
# Optional list of secret specifications
secrets:
# deploy_type: How the secret is deployed in Kubernetes, either `env` or
# `volume`
- deploy_type: "env"
# deploy_target: (Optional) The environment variable (for `deploy_type`
# `env`) or the file path (for `deploy_type` `volume`) where the secret is
# exposed. If `key` is not provided, `deploy_target` should be None.
deploy_target: "SQL_CONN"
# secret: Name of the secrets object in Kubernetes
secret: "airflow-secrets"
# key: (Optional) Key of the secret within the Kubernetes Secret. If not
# provided with `deploy_type` `env`, all secrets in the object are mounted
key: "sql_alchemy_conn"
# Apache Airflow macros to be exposed for the parameters
# List of macros can be found here:
# https://airflow.apache.org/docs/apache-airflow/stable/macros-ref.html
macro_params: [ds, prev_ds]
# Apache Airflow variables to be exposed for the parameters
variables_params: [env]
# Spark nodes are grouped by default to benefit from lazy execution.
# If you want to disable this behaviour, set the following value to False.
group_spark_nodes: True
# Optional resources specification
#resources:
# Default configuration used by all nodes that do not declare their own
# resource configuration. It's optional; if a node declares no resource
# configuration, __default__ is applied, and if __default__ is absent the
# cluster defaults are used.
#__default__:
# Optional labels to be put into pod node selector
#node_selectors:
# Node selectors are user-provided key-value pairs matching node labels
#node_pool_label/k8s.io: example_value
# Optional labels to apply on pods
#labels:
#running: airflow
# Optional annotations to apply on pods
#annotations:
#iam.amazonaws.com/role: airflow
#vault.hashicorp.com/agent-inject-template-foo: |
# {{- with secret "database/creds/db-app" -}}
# postgres://{{ .Data.username }}:{{ .Data.password }}@postgres:5432/mydb
# {{- end }}
# Optional list of kubernetes tolerations
#tolerations:
#- key: "group"
#value: "data-processing"
#effect: "NoExecute"
#- key: "group"
#operator: "Equal",
#value: "data-processing",
#effect: "NoSchedule"
#requests:
# Optional amount of CPU resources requested from k8s
#cpu: "1"
# Optional amount of memory resources requested from k8s
#memory: "1Gi"
#limits:
# Optional CPU resource limit on k8s
#cpu: "1"
# Optional memory resource limit on k8s
#memory: "1Gi"
# Other arbitrary resource configurations to use
#custom_resource_config_name:
# Optional labels to apply on pods
#labels:
# Labels are user-provided key-value pairs
#label_key: label_value
#requests:
# Optional amount of CPU resources requested from k8s
#cpu: "1"
# Optional amount of memory resources requested from k8s
#memory: "1Gi"
#limits:
# Optional CPU resource limit on k8s
#cpu: "1"
# Optional memory resource limit on k8s
#memory: "1Gi"
# Optional external dependencies configuration
#external_dependencies:
# You can select a whole DAG
#- dag_id: upstream-dag
# or be more specific
#- dag_id: another-upstream-dag
# with a specific task to wait for
# task_id: with-precise-task
# Maximum time (in minutes) to wait for the external DAG to finish before this
# pipeline fails; the default is 1440 (= 1 day)
# timeout: 2
# Checks whether the external DAG exists before waiting for it to finish. If
# it does not exist, this pipeline fails. Set to true by default.
# check_existence: False
# Time difference (in minutes) with the previous execution to look at;
# the default is 0, meaning no difference
# execution_delta: 10
# Optional authentication to MLflow API
#authentication:
# Strategy that generates the credentials, supported values are:
# - Null
# - GoogleOAuth2 (generating OAuth2 tokens for service account provided by
# GOOGLE_APPLICATION_CREDENTIALS)
# - Vars (credentials fetched from Airflow's Variable.get - specify variable keys,
# matching MLflow authentication env variable names, in `params`,
# e.g. ["MLFLOW_TRACKING_USERNAME", "MLFLOW_TRACKING_PASSWORD"])
#type: GoogleOAuth2
#params: []
#spark:
# submit_job_operator:
# Airflow operator to use for submitting a Spark job: SparkSubmitOperator,
# DataprocSubmitJobOperator or KubernetesSparkOperator
# region: None
# project_id: None
# cluster_name: None
# create_cluster: False
# Optional custom Kubernetes pod templates applied on a per-node basis
#kubernetes_pod_templates:
# Name of the node you want to apply the custom template to.
# If you specify __default__, this template will be applied to all nodes.
# Otherwise it will only be applied to nodes tagged with `k8s_template:<node_name>`
# node_name:
# Kubernetes pod template.
# It's the full content of the pod-template file (as a string).
# `run_config.volume` and the `MLFLOW_RUN_ID` env var are disabled when this is set.
# Note: Python f-string formatting is applied to this string, so
# you can also use dynamic values, e.g. to calculate the pod name.
# template:
# Optionally, you can also override the image
# image:
# ____ EXAMPLE _______________
#
#kubernetes_pod_templates:
# spark:
# template: |-
# apiVersion: v1
# kind: Pod
# metadata:
# name: newname
# spec:
# containers:
# - name: base
# env:
# - name: CUSTOM_ENV
# value: env1
#
# Configuration for Spark jobs
# spark:
# type: k8s
# cluster_name: spark_k8s
# run_script: local:///path/to/run/script/in/image.py
"""
class Config:
def __init__(self, raw):
self._raw = raw
def _get_or_default(self, prop, default):
return self._raw.get(prop, default)
    def _get_or_fail(self, prop):
        if prop in self._raw:
            return self._raw[prop]
        raise MissingConfigException(
            f"Missing required configuration: '{self._get_prefix()}{prop}'."
        )
def _get_prefix(self):
return ""
    def __eq__(self, other):
        # Guard against comparisons with non-Config objects, which would
        # otherwise raise AttributeError on the missing _raw attribute
        return isinstance(other, Config) and self._raw == other._raw
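# Illustrative sketch of the base-class behaviour (values are hypothetical):
# Config({"image": "repo/img"})._get_or_fail("image") returns "repo/img",
# while Config({})._get_or_fail("image") raises MissingConfigException with
# the subclass's _get_prefix() prepended to the property name.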
class ResourceNodeConfig(Config):
@property
def cpu(self):
return self._get_or_default("cpu", None)
@property
def memory(self):
return self._get_or_default("memory", None)
class ResourceConfig(Config):
@property
def annotations(self):
return self._get_or_default("annotations", {})
@property
def tolerations(self):
return self._get_or_default("tolerations", {})
@property
def node_selectors(self):
return self._get_or_default("node_selectors", {})
@property
def labels(self):
return self._get_or_default("labels", {})
@property
def requests(self):
return ResourceNodeConfig(self._get_or_default("requests", {}))
@property
def limits(self):
return ResourceNodeConfig(self._get_or_default("limits", {}))
class ResourcesConfig(Config):
def __getattr__(self, item):
return self[item]
def __getitem__(self, item):
return ResourceConfig(self._get_or_default(item, {}))
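# Illustrative sketch: given cfg = ResourcesConfig({"__default__":
# {"requests": {"cpu": "1"}}}), both cfg["__default__"] and cfg.__default__
# return a ResourceConfig, so cfg.__default__.requests.cpu == "1"; unknown
# names fall back to an empty ResourceConfig.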
class ExternalDependencyConfig(Config):
@property
def dag_id(self):
return self._get_or_fail("dag_id")
@property
def task_id(self):
return self._get_or_default("task_id", None)
@property
def check_existence(self):
return self._get_or_default("check_existence", True)
@property
def execution_delta(self):
return self._get_or_default("execution_delta", 0)
@property
def timeout(self):
return self._get_or_default("timeout", 60 * 24)
class AuthenticationConfig(Config):
@property
def type(self):
return self._get_or_default("type", "Null")
@property
def params(self):
return self._get_or_default("params", [])
class SparkConfig(Config):
@property
def type(self):
return self._get_or_default("type", "none")
@property
def region(self):
return self._get_or_default("region", "None")
@property
def cluster_name(self):
return self._get_or_default("cluster_name", "None")
@property
def project_id(self):
return self._get_or_default("project_id", "None")
@property
def operator_factory(self):
return self._get_or_default("operator_factory", None)
@property
def artifacts_path(self):
return self._get_or_default("artifacts_path", None)
@property
def user_init_path(self):
return self._get_or_default("user_init_path", None)
@property
def user_post_init_path(self):
return self._get_or_default("user_post_init_path", None)
@property
def cluster_config(self):
data = self._get_or_default("cluster_config", {})
if self.type in ["k8s", "kubernetes"]:
return SparkK8SConfig(data)
return data
@property
def requires_artifacts_dump(self):
return self.type not in ["k8s", "kubernetes"]
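# Illustrative sketch: SparkConfig({"type": "k8s", "cluster_config":
# {"run_script": "local:///app/run.py"}}).cluster_config returns a
# SparkK8SConfig, so .run_script resolves to the given path; for any other
# type the raw cluster_config dict is returned unchanged.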
class SparkK8SStorageConfig(Config):
@property
def class_name(self):
return self._get_or_default("class_name", "standard")
@property
def size(self):
return self._get_or_default("size", None)
class SparkK8SConfig(Config):
@property
def run_script(self):
return self._get_or_fail("run_script")
@property
def image(self):
return self._get_or_default("image", None)
@property
def conf(self):
return self._get_or_default("conf", {})
@property
def driver_port(self):
return self._get_or_default("driver_port", None)
@property
def block_manager_port(self):
return self._get_or_default("block_manager_port", None)
@property
def secrets(self):
return self._get_or_default("secrets", {})
@property
def labels(self):
return self._get_or_default("labels", {})
@property
def local_storage(self):
return SparkK8SStorageConfig(self._get_or_default("local_storage", {}))
@property
def env_vars(self):
return self._get_or_default("env_vars", {})
@property
def requests(self):
return ResourceNodeConfig(self._get_or_default("requests", {}))
@property
def limits(self):
return ResourceNodeConfig(self._get_or_default("limits", {}))
@property
def num_executors(self):
return self._get_or_default("num_executors", "1")
@property
def jars(self):
return self._get_or_default("jars", None)
@property
def repositories(self):
return self._get_or_default("repositories", None)
@property
def packages(self):
return self._get_or_default("packages", None)
class RunConfig(Config):
@property
def image(self):
return self._get_or_fail("image")
@property
def image_pull_policy(self):
return self._get_or_default("image_pull_policy", "IfNotPresent")
@property
def startup_timeout(self):
return self._get_or_default("startup_timeout", 600)
@property
def namespace(self):
return self._get_or_fail("namespace")
@property
def failure_handlers(self):
cfg = self._get_or_default("failure_handlers", [])
supported_types = FailureHandlerConfig.supported_types()
return [
FailureHandlerConfig(handler)
for handler in cfg
if handler["type"] in supported_types
]
@property
def experiment_name(self):
return self._get_or_fail("experiment_name")
@property
def run_name(self):
return self._get_or_default("run_name", self.experiment_name)
@property
def cron_expression(self):
return self._get_or_default("cron_expression", "@daily")
@property
def start_date(self):
start_date = self._get_or_default("start_date", None)
if start_date:
start_date = str(start_date)
return start_date
@property
def description(self):
return self._get_or_default("description", None)
@property
def volume(self):
cfg = self._get_or_default("volume", {})
return VolumeConfig(cfg)
@property
def secrets(self):
cfg = self._get_or_default("secrets", [])
return [SecretConfig(secret) for secret in cfg]
@property
def macro_params(self):
return self._get_or_default("macro_params", [])
@property
def variables_params(self):
return self._get_or_default("variables_params", [])
@property
def resources(self):
cfg = self._get_or_default("resources", {})
return ResourcesConfig(cfg)
@property
def external_dependencies(self):
deps = self._get_or_default("external_dependencies", [])
return [ExternalDependencyConfig(cfg) for cfg in deps]
@property
def image_pull_secrets(self):
return self._get_or_default("image_pull_secrets", None)
@property
def service_account_name(self):
return self._get_or_default("service_account_name", None)
@property
def auth_config(self):
cfg = self._get_or_default(
"authentication", {"type": "Null", "params": []}
)
return AuthenticationConfig(cfg)
@property
def spark(self):
cfg = self._get_or_default("spark", {})
return SparkConfig(cfg)
@property
def env_vars(self):
return self._get_or_default("env_vars", [])
@property
def kubernetes_pod_templates(self):
cfg = self._get_or_default("kubernetes_pod_templates", {})
return KubernetesPodTemplates(cfg)
@property
def group_spark_nodes(self):
return self._get_or_default("group_spark_nodes", False)
def _get_prefix(self):
return "run_config."
class VolumeConfig(Config):
@property
def disabled(self):
return self._get_or_default("disabled", False)
@property
def storageclass(self):
return self._get_or_default("storageclass", None)
@property
def size(self):
return self._get_or_default("size", "1Gi")
@property
def access_modes(self):
return self._get_or_default("access_modes", ["ReadWriteOnce"])
@property
def skip_init(self):
return self._get_or_default("skip_init", False)
@property
def owner(self):
return self._get_or_default("owner", 0)
def _get_prefix(self):
return "run_config.volume."
class FailureHandlerConfig(Config):
@staticmethod
def supported_types():
return ["slack"]
@property
def type(self):
return self._get_or_fail("type")
@property
def connection_id(self):
return self._get_or_fail("connection_id")
@property
def message_template(self):
return self._get_or_default("message_template", "Task failed!")
class SecretConfig(Config):
@property
def deploy_type(self):
return self._get_or_default("deploy_type", "env")
@property
def deploy_target(self):
return self._get_or_default("deploy_target", None)
@property
def secret(self):
return self._get_or_fail("secret")
@property
def key(self):
return self._get_or_default("key", None)
def _get_prefix(self):
return "run_config.secrets."
class PluginConfig(Config):
@property
def host(self):
return self._get_or_fail("host")
@property
def output(self):
return self._get_or_fail("output")
@property
def run_config(self):
cfg = self._get_or_default("run_config", {})
return RunConfig(cfg)
@staticmethod
def sample_config(**kwargs):
return DEFAULT_CONFIG_TEMPLATE.format(**kwargs)
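    # Illustrative sketch (values are hypothetical): PluginConfig.sample_config(
    #     url="https://airflow.example.com",
    #     output="dags/",
    #     image="registry.example.com/project:latest",
    #     project="my-project",
    #     run_name="my-project",
    # ) renders DEFAULT_CONFIG_TEMPLATE with these placeholders filled in.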
@staticmethod
def initialize_github_actions(project_name, where, templates_dir):
os.makedirs(where / ".github/workflows", exist_ok=True)
for template in ["on-merge-to-master.yml", "on-push.yml"]:
file_path = where / ".github/workflows" / template
template_file = templates_dir / f"github-{template}"
with open(template_file, "r") as tfile, open(file_path, "w") as f:
f.write(tfile.read().format(project_name=project_name))
class KubernetesPodTemplate(Config):
@property
def template(self):
return self._get_or_default("template", None)
@property
def image(self):
return self._get_or_default("image", None)
def __len__(self):
return len(self._raw)
class KubernetesPodTemplates(Config):
def __getattr__(self, item):
return self[item]
def __getitem__(self, item):
return KubernetesPodTemplate(self._get_or_default(item, {}))
def __len__(self):
return len(self._raw)
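# Minimal usage sketch (assumes PyYAML is installed; the YAML below is a
# hypothetical, abbreviated airflow-k8s.yml):
if __name__ == "__main__":
    import yaml

    raw = yaml.safe_load(
        """
host: https://airflow.example.com
output: dags/
run_config:
  image: registry.example.com/project:latest
  experiment_name: my-project
  namespace: airflow
"""
    )
    config = PluginConfig(raw)
    # Required keys fail fast via _get_or_fail; optional keys fall back to
    # the defaults defined on the properties above.
    print(config.host)                        # https://airflow.example.com
    print(config.run_config.image)            # registry.example.com/project:latest
    print(config.run_config.cron_expression)  # "@daily" (falls back to default)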