getindata/kedro-airflow-k8s

View on GitHub
kedro_airflow_k8s/operators/node_pod.py

Summary

Maintainability
A
0 mins
Test Coverage
A
93%
"""
Module containing an Apache Airflow operator that creates a k8s pod to execute
a kedro node.
"""

import logging
from typing import Dict, List, Optional

from airflow.kubernetes.pod_generator import PodGenerator
from airflow.kubernetes.secret import Secret
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
    KubernetesPodOperator,
)
from kubernetes.client import models as k8s


class NodePodOperator(KubernetesPodOperator):
    """
    Operator that starts a pod from an image containing a kedro project and runs a
    single node of a kedro pipeline in it. It wraps ``KubernetesPodOperator`` and
    translates a set of convenient, flat options into the k8s objects and the pod
    template expected by the parent class.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        node_name: str,
        namespace: str,
        image: str,
        image_pull_policy: str,
        env: str,
        task_id: str,
        pipeline: str = "__default__",
        pvc_name: Optional[str] = None,
        startup_timeout: int = 600,
        volume_disabled: bool = False,
        volume_owner: int = 0,
        mlflow_enabled: bool = True,
        requests_cpu: Optional[str] = None,
        requests_memory: Optional[str] = None,
        limits_cpu: Optional[str] = None,
        limits_memory: Optional[str] = None,
        node_selector_labels: Optional[Dict[str, str]] = None,
        labels: Optional[Dict[str, str]] = None,
        image_pull_secrets: Optional[str] = None,
        service_account_name: Optional[str] = "default",
        tolerations: Optional[List[Dict[str, str]]] = None,
        annotations: Optional[Dict[str, str]] = None,
        secrets: Optional[List[Secret]] = None,
        source: str = "/home/kedro/data",
        parameters: Optional[str] = "",
        kubernetes_pod_template: Optional[str] = None,
        **kwargs,
    ):
        """

        :param node_name: name from the kedro pipeline, provided with '--node' option
        :param namespace: k8s namespace the pod will execute in
        :param pvc_name: name of the persistent volume claim used as shared storage
                         for this pod (only used when the volume is not disabled)
        :param image: image the pod is started from
        :param image_pull_policy: k8s image pull policy
        :param env: kedro pipeline configuration name, provided with '--env' option
        :param pipeline: kedro pipeline name, provided with '--pipeline' option
        :param task_id: Airflow id to override
        :param startup_timeout: after the amount provided in seconds the pod start is
                                timed out
        :param volume_disabled: if set to true, shared volume is not attached
        :param volume_owner: if volume is not disabled, fs group associated with this pod
        :param mlflow_enabled: if mlflow_run_id value is passed from xcom
        :param requests_cpu: k8s requests cpu value
        :param requests_memory: k8s requests memory value
        :param limits_cpu: k8s limits cpu value
        :param limits_memory: k8s limits memory value
        :param node_selector_labels: dictionary of node selector labels to be put into
                                     pod node selector
        :param labels: dictionary of labels to apply on pod
        :param image_pull_secrets: Any image pull secrets to be given to the pod.
                                   If more than one secret is required, provide a
                                   comma separated list: secret_a,secret_b
        :param service_account_name: Name of the service account
        :param tolerations: list of dictionaries describing tolerations; the keys
                            'effect', 'key', 'operator' and 'value' are read
        :param annotations: dictionary of annotations to apply on pod
        :param secrets: secrets forwarded unchanged to the parent operator
        :param source: mount point of shared storage
        :param parameters: additional kedro run parameters, provided with '--params'
                           option; ``None`` is treated as an empty string
        :param kubernetes_pod_template: if set, used instead of the generated minimal
                                        pod template
        :param kwargs: any other arguments, forwarded to the parent operator
        """
        # These attributes are read by `minimal_pod_template`, which is evaluated
        # below while the arguments of `super().__init__` are being built, so they
        # have to be assigned before the parent constructor is called.
        self._task_id = task_id
        self._volume_disabled = volume_disabled
        self._pvc_name = pvc_name
        self._mlflow_enabled = mlflow_enabled
        self._kubernetes_pod_template = kubernetes_pod_template

        # The shared storage is mounted only when the volume is enabled; the mount
        # name matches the "storage" volume declared in the minimal pod template.
        volume_mounts = (
            []
            if volume_disabled
            else [k8s.V1VolumeMount(mount_path=source, name="storage")]
        )

        super().__init__(
            task_id=task_id,
            security_context=self.create_security_context(
                volume_disabled, volume_owner
            ),
            namespace=namespace,
            image=image,
            image_pull_policy=image_pull_policy,
            image_pull_secrets=image_pull_secrets,
            service_account_name=service_account_name,
            arguments=[
                "kedro",
                "run",
                "--env",
                env,
                "--pipeline",
                pipeline,
                "--node",
                node_name,
                "--params",
                # `parameters` is Optional: an explicit None must not end up in the
                # container arguments, so it falls back to the default empty string.
                parameters if parameters is not None else "",
            ],
            volume_mounts=volume_mounts,
            resources=self.create_resources(
                requests_cpu, requests_memory, limits_cpu, limits_memory
            ),
            startup_timeout_seconds=startup_timeout,
            is_delete_operator_pod=True,
            pod_template_file=self.minimal_pod_template,
            node_selectors=node_selector_labels,
            labels=labels,
            tolerations=self.create_tolerations(tolerations),
            annotations=annotations,
            secrets=secrets,
            **kwargs,
        )

    def execute(self, context):
        """
        Executes task in pod with provided configuration (super implementation used).
        The generated pod request object is logged at debug level beforehand.
        :param context: Airflow task context, passed through to the parent operator
        :return: whatever the parent operator's ``execute`` returns
        """
        logging.debug(self.create_pod_request_obj())
        return super().execute(context)

    @staticmethod
    def create_resources(
        requests_cpu, requests_memory, limits_cpu, limits_memory
    ) -> k8s.V1ResourceRequirements:
        """
        Creates k8s resources based on requests and limits. Only values that are
        set (truthy) are included; when none are given, empty dictionaries are used.
        :param requests_cpu: value stored under the "cpu" key of requests
        :param requests_memory: value stored under the "memory" key of requests
        :param limits_cpu: value stored under the "cpu" key of limits
        :param limits_memory: value stored under the "memory" key of limits
        :return: resource requirements holding the collected requests and limits
        """
        requests = {}
        if requests_cpu:
            requests["cpu"] = requests_cpu
        if requests_memory:
            requests["memory"] = requests_memory
        limits = {}
        if limits_cpu:
            limits["cpu"] = limits_cpu
        if limits_memory:
            limits["memory"] = limits_memory
        return k8s.V1ResourceRequirements(limits=limits, requests=requests)

    @property
    def minimal_pod_template(self):
        """
        This template is required since 'volumes' arguments are not templated via direct
        API nor passing xcom values in pod definition.
        If a custom pod template was given to the constructor, it is returned as is.
        :return: partial pod definition that should be complemented by other operator
                parameters
        """
        if self._kubernetes_pod_template:
            return self._kubernetes_pod_template
        # Base definition: a uniquely named pod with a single "base" container whose
        # env list is left open so the MLflow variable can be appended below.
        minimal_pod_template = f"""
apiVersion: v1
kind: Pod
metadata:
  name: {PodGenerator.make_unique_pod_id(self._task_id)}
spec:
  containers:
    - name: base
      env:
"""
        if self._mlflow_enabled:
            # Plain (non f-) string on purpose: the double braces must reach the
            # template unchanged as a Jinja expression pulling the run id from xcom.
            minimal_pod_template += """
        - name: MLFLOW_RUN_ID
          value: {{ task_instance.xcom_pull(key="mlflow_run_id") }}
"""
        if not self._volume_disabled:
            # Declares the "storage" volume backed by the configured claim.
            minimal_pod_template += f"""
  volumes:
    - name: storage
      persistentVolumeClaim:
        claimName: {self._pvc_name}
"""
        return minimal_pod_template

    @staticmethod
    def create_security_context(
        volume_disabled: bool, volume_owner: int
    ) -> k8s.V1PodSecurityContext:
        """
        Creates security context based on volume information
        :param volume_disabled: when true, an empty security context is created
        :param volume_owner: value used as fs_group when the volume is enabled
        :return: pod security context, with fs_group set only if the volume is enabled
        """
        if volume_disabled:
            return k8s.V1PodSecurityContext()
        return k8s.V1PodSecurityContext(fs_group=volume_owner)

    @staticmethod
    def create_tolerations(
        tolerations: Optional[List[Dict[str, str]]] = None
    ) -> List[k8s.V1Toleration]:
        """
        Creates k8s tolerations
        :param tolerations: list of dictionaries; the 'effect', 'key', 'operator' and
                            'value' entries are read, missing ones become None
        :return: list of k8s tolerations, empty if no tolerations were given
        """
        if not tolerations:
            return []
        return [
            k8s.V1Toleration(
                effect=toleration.get("effect"),
                key=toleration.get("key"),
                operator=toleration.get("operator"),
                value=toleration.get("value"),
            )
            for toleration in tolerations
        ]