getindata/kedro-airflow-k8s
kedro_airflow_k8s/cli.py

import logging
import os
import webbrowser
from pathlib import Path
from typing import Dict, List, Optional

import click
from tabulate import tabulate

from kedro_airflow_k8s.airflow import AirflowClient
from kedro_airflow_k8s.cli_helper import CliHelper
from kedro_airflow_k8s.config import PluginConfig
from kedro_airflow_k8s.context_helper import ContextHelper
from kedro_airflow_k8s.template import (
    get_cron_expression,
    get_dag_filename_and_template_stream,
)


@click.group("airflow-k8s")
def commands():
    """Kedro plugin adding support for Airflow on K8S"""
    pass


@commands.group(
    name="airflow-k8s",
    context_settings=dict(help_option_names=["-h", "--help"]),
)
@click.option(
    "-e",
    "--env",
    "env",
    type=str,
    default=lambda: os.environ.get("KEDRO_ENV", "local"),
    help="Environment to use.",
)
@click.option(
    "-p",
    "--pipeline",
    "pipeline",
    type=str,
    default="__default__",
    required=False,
    help="Pipeline name to pick.",
)
@click.pass_obj
@click.pass_context
def airflow_group(ctx, metadata, env, pipeline):
    ctx.ensure_object(dict)
    ctx.obj["context_helper"] = ContextHelper.init(metadata, env, pipeline)


@airflow_group.command()
@click.option(
    "-i",
    "--image",
    "image",
    type=str,
    required=False,
    help="Image to override.",
)
@click.pass_context
def compile(ctx, image: Optional[str], target_path: str = "dags/"):
    """Create an Airflow DAG for a project"""
    (
        dag_name,
        template_stream,
        spark_template_streams,
    ) = get_dag_filename_and_template_stream(
        ctx, image=image, cron_expression=get_cron_expression(ctx)
    )
    if spark_template_streams:
        CliHelper.dump_spark_artifacts(
            ctx, target_path, spark_template_streams
        )

    CliHelper.dump_templates(dag_name, target_path, template_stream)
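
# Illustrative invocation (the image tag is an example only):
#   kedro airflow-k8s compile --image gcr.io/my-project/my-image:latest
# The rendered DAG (and any Spark artifacts) are written under the local
# dags/ directory, since target_path is not exposed as a CLI option.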


@airflow_group.command()
@click.option(
    "-o",
    "--output",
    "output",
    type=str,
    required=False,
    help="Location where DAG file should be uploaded, for GCS use gs:// or "
    "gcs:// prefix, other notations indicate locally mounted filesystem",
)
@click.option(
    "-i",
    "--image",
    "image",
    type=str,
    required=False,
    help="Image to override.",
)
@click.pass_context
def upload_pipeline(ctx, output: Optional[str], image: Optional[str]):
    """
    Uploads the pipeline to the Airflow DAG location
    """
    (
        dag_name,
        template_stream,
        spark_template_streams,
    ) = get_dag_filename_and_template_stream(
        ctx, image=image, cron_expression=get_cron_expression(ctx)
    )

    CliHelper.conditionally_handle_spark_artifacts(spark_template_streams, ctx)
    output = output or ctx.obj["context_helper"].config.output
    logging.info(f"Deploying to {output}")
    CliHelper.dump_templates(dag_name, output, template_stream)
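
# Illustrative invocation (the bucket is an example only):
#   kedro airflow-k8s upload-pipeline --output gs://my-bucket/dags
# Without --output, the location from the plugin configuration
# (context_helper.config.output) is used.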


@airflow_group.command()
@click.option(
    "-o",
    "--output",
    "output",
    type=str,
    required=False,
    help="Location where DAG file should be uploaded, for GCS use gs:// or "
    "gcs:// prefix, other notations indicate locally mounted filesystem",
)
@click.option(
    "-c",
    "--cron-expression",
    type=str,
    help="Cron expression for recurring run",
    required=False,
)
@click.option(
    "-d",
    "--dag-name",
    "dag_name",
    type=str,
    required=False,
    help="Allows overriding dag id and dag file name for a purpose of multiple variants"
    " of experiments",
)
@click.pass_context
def schedule(
    ctx,
    output: Optional[str],
    cron_expression: str,
    dag_name: Optional[str] = None,
):
    """
    Uploads the pipeline to Airflow with the given schedule
    """
    (
        dag_filename,
        template_stream,
        spark_template_streams,
    ) = get_dag_filename_and_template_stream(
        ctx,
        cron_expression=get_cron_expression(ctx, cron_expression),
        dag_name=dag_name,
    )

    output = output or ctx.obj["context_helper"].config.output

    CliHelper.conditionally_handle_spark_artifacts(spark_template_streams, ctx)
    CliHelper.dump_templates(dag_filename, output, template_stream)
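
# Illustrative invocation (cron expression and DAG name are examples only):
#   kedro airflow-k8s schedule --cron-expression "0 4 * * *" \
#       --dag-name my-experiment-variant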


@airflow_group.command()
@click.option(
    "-o",
    "--output",
    "output",
    type=str,
    required=False,
    help="Location where DAG file should be uploaded, for GCS use gs:// or "
    "gcs:// prefix, other notations indicate locally mounted filesystem",
)
@click.option(
    "-d",
    "--dag-name",
    "dag_name",
    type=str,
    required=False,
    help="Allows overriding dag id and dag file name for a purpose of multiple variants"
    " of experiments",
)
@click.option(
    "-w",
    "--wait-for-completion",
    "wait_for_completion",
    type=int,
    required=False,
    help="If set, tells plugin to wait for dag run to finish and how long (minutes)",
)
@click.option(
    "-i",
    "--image",
    "image",
    type=str,
    required=False,
    help="Image to override.",
)
@click.pass_context
def run_once(
    ctx,
    output: Optional[str],
    dag_name: Optional[str],
    wait_for_completion: Optional[int],
    image: Optional[str],
):  # pylint: disable=too-many-arguments
    """
    Uploads pipeline to Airflow and runs once
    """
    (
        dag_filename,
        template_stream,
        spark_template_streams,
    ) = get_dag_filename_and_template_stream(
        ctx,
        dag_name=dag_name,
        image=image,
        cron_expression=None,
        with_external_dependencies=False,
    )
    context_helper = ctx.obj["context_helper"]
    output = output or context_helper.config.output

    CliHelper.conditionally_handle_spark_artifacts(spark_template_streams, ctx)
    CliHelper.dump_templates(dag_filename, output, template_stream)

    airflow_client = AirflowClient(context_helper.config.host)
    dag = airflow_client.wait_for_dag(
        dag_id=dag_name or context_helper.config.run_config.run_name,
        tag=f'commit_sha:{context_helper.session.store["git"]["commit_sha"]}',
    )
    dag_run_id = airflow_client.trigger_dag_run(dag.dag_id)

    if (wait_for_completion or 0) > 0:
        assert (
            airflow_client.wait_for_dag_run_completion(
                dag.dag_id, dag_run_id, wait_for_completion
            )
            == "success"
        )
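
# Illustrative invocation (the timeout value is an example only):
#   kedro airflow-k8s run-once --wait-for-completion 30
# With --wait-for-completion set, the assert above fails unless the
# triggered run reaches the "success" state within the given time.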


@airflow_group.command()
@click.pass_context
def list_pipelines(ctx):
    """
    List pipelines generated by this plugin
    """

    context_helper = ctx.obj["context_helper"]
    airflow_client = AirflowClient(context_helper.config.host)

    dags = airflow_client.list_dags("generated_with_kedro_airflow_k8s")

    def name(tags: List[Dict[str, str]]) -> str:
        # Generated DAGs carry a tag starting with "experiment_name" followed
        # by a one-character separator and the name itself; strip that prefix.
        experiment_tag = [
            t["name"] for t in tags if t["name"].startswith("experiment_name")
        ][0]
        return experiment_tag[len("experiment_name") + 1 :]  # noqa: E203

    pipelines = [[name(d.tags), d.dag_id] for d in dags]
    pipelines.sort()
    click.echo(tabulate(pipelines, headers=["Name", "ID"]))
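
# Illustrative output (names and ids are examples only):
#   Name           ID
#   -------------  --------------
#   my-experiment  my-project-dag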


@airflow_group.command()
@click.option(
    "-d",
    "--dag-name",
    "dag_name",
    type=str,
    required=False,
    help="View for this specific DAG will be opened",
)
@click.pass_context
def ui(ctx, dag_name: Optional[str] = None):
    """Open Apache Airflow UI in new browser tab"""
    host = ctx.obj["context_helper"].config.host
    if dag_name:
        host = f"{host}/tree?dag_id={dag_name}"
    webbrowser.open_new_tab(host)
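
# Illustrative invocation (the DAG name is an example only):
#   kedro airflow-k8s ui --dag-name my-project-dag
# This opens <host>/tree?dag_id=my-project-dag in the default browser.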


@airflow_group.command()
@click.argument("airflow_url", type=str)
@click.option("--with-github-actions", is_flag=True, default=False)
@click.option("--output", type=str, default=False)
@click.pass_context
def init(ctx, airflow_url: str, with_github_actions: bool, output: str):
    """Initializes configuration for the plugin"""
    context_helper = ctx.obj["context_helper"]
    project_name = context_helper.context.project_path.name
    if with_github_actions:
        image = f"gcr.io/${{google_project_id}}/{project_name}:${{commit_id}}"
        run_name = f"{project_name}:${{commit_id}}"
    else:
        image = project_name
        run_name = project_name

    sample_config = PluginConfig.sample_config(
        url=airflow_url,
        image=image,
        project=project_name,
        run_name=run_name,
        output=output,
    )
    config_path = Path.cwd().joinpath("conf/base/airflow-k8s.yaml")
    config_path.parent.mkdir(exist_ok=True, parents=True)
    with open(config_path, "w") as f:
        f.write(sample_config)

    click.echo(f"Configuration generated in {config_path}")

    if with_github_actions:
        PluginConfig.initialize_github_actions(
            project_name,
            where=Path.cwd(),
            templates_dir=Path(__file__).parent / "templates",
        )
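
# Illustrative invocation (the URL is an example only):
#   kedro airflow-k8s init https://airflow.example.com --with-github-actions
# This writes conf/base/airflow-k8s.yaml; with --with-github-actions it also
# renders the GitHub Actions templates bundled with the plugin.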