CSCfi/pebbles

View on GitHub
pebbles/worker/controllers.py

Summary

Maintainability
D
1 day
Test Coverage
import logging
import os
import time
import traceback
from random import randrange

import requests

from pebbles.models import ApplicationSession, Task
from pebbles.utils import find_driver_class

# Name of the global lock WorkspaceController obtains before it processes tasks
WS_CONTROLLER_TASK_LOCK_NAME = 'workspace-controller-tasks'

# Passed as the 'limit' argument when ApplicationSessionController queries application sessions
SESSION_CONTROLLER_LIMIT_SIZE = 50

# Maximum age of a cached driver instance. It is added to the driver's create_ts and compared
# against time.time(), so the unit is seconds
DRIVER_CACHE_LIFETIME = 900


class ControllerBase:
    """
    Common base class for the worker controllers.

    Stores the worker identity, configuration and API client, and provides driver lookup with
    caching plus helpers for randomized polling intervals.
    """

    def __init__(self, worker_id, config, cluster_config, client, controller_name):
        self.worker_id = worker_id
        self.config = config
        self.cluster_config = cluster_config
        self.client = client
        # used as the prefix of the polling interval environment variable names
        self.controller_name = controller_name
        # time.time() based timestamp of the next scheduled check, see update_next_check_ts()
        self.next_check_ts = 0

    def get_driver(self, cluster_name):
        """Create driver instance for given cluster.
        We cache the driver instances to avoid login for every new request.

        :param cluster_name: value of the 'name' key of an entry in cluster_config['clusters']
        :return: driver instance, either taken from the cache or freshly created
        :raises RuntimeWarning: if no cluster or no driver class matches
        """
        cluster = next(
            (c for c in self.cluster_config['clusters'] if c.get('name') == cluster_name),
            None
        )
        if cluster is None:
            raise RuntimeWarning('No matching cluster in configuration for %s' % cluster_name)

        # check cache: the instance is stored in the cluster configuration entry itself
        cached_instance = cluster.get('driver_instance')
        if cached_instance is not None:
            # we found an existing instance, use that if it is still valid
            still_fresh = cached_instance.create_ts + DRIVER_CACHE_LIFETIME > time.time()
            if still_fresh and not cached_instance.is_expired():
                return cached_instance

        # create the driver by finding out the class and creating an instance
        driver_class = find_driver_class(cluster.get('driver'))
        if not driver_class:
            raise RuntimeWarning('No matching driver %s found for %s' % (cluster.get('driver'), cluster_name))

        # create an instance, test the connection and populate the cache
        driver_instance = driver_class(logging.getLogger(), self.config, cluster, self.client.token)
        driver_instance.connect()
        cluster['driver_instance'] = driver_instance

        return driver_instance

    def update_next_check_ts(self, polling_interval_min, polling_interval_max):
        """
        Schedule the next check at a random offset from now.

        The offset is a whole number of seconds between polling_interval_min and
        polling_interval_max, both ends included.
        """
        self.next_check_ts = time.time() + randrange(polling_interval_min, polling_interval_max + 1)

    def get_polling_interval(self, default_min, default_max):
        """
        Read the polling interval from worker environment variables, if present. If not present,
        use given controller specific default values.

        The defaults are also used when a variable does not contain an integer or when the
        minimum is larger than the maximum.

        :param default_min: minimum interval in seconds to use when the environment does not override it
        :param default_max: maximum interval in seconds to use when the environment does not override it
        :return: tuple (polling_interval_min, polling_interval_max)
        """
        min_var_name = f"{self.controller_name}_POLLING_INTERVAL_SEC_MIN"
        max_var_name = f"{self.controller_name}_POLLING_INTERVAL_SEC_MAX"
        try:
            polling_interval_min = int(os.getenv(min_var_name, default_min))
            polling_interval_max = int(os.getenv(max_var_name, default_max))
        except (TypeError, ValueError):
            # a malformed value must not prevent the controller from starting
            logging.warning(
                '%s or %s is not a valid integer, using default values instead',
                min_var_name, max_var_name
            )
            return default_min, default_max

        logging.info(f"{min_var_name} is set to {polling_interval_min}")
        logging.info(f"{max_var_name} is set to {polling_interval_max}")
        if polling_interval_min > polling_interval_max:
            logging.warning(f"{self.controller_name}_POLLING_INTERVAL_SEC_MIN is larger than "
                            f"{self.controller_name}_POLLING_INTERVAL_SEC_MAX, using default values instead")
            return default_min, default_max
        return polling_interval_min, polling_interval_max


class ApplicationSessionController(ControllerBase):
    """
    Controller that takes care of application sessions
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.polling_interval_min, self.polling_interval_max = self.get_polling_interval(2, 5)

    def update_application_session(self, application_session):
        """
        Hand the application session over to the driver of its cluster.

        :param application_session: session dict with 'id' and 'provisioning_config' keys
        :raises RuntimeWarning: from get_driver() if the cluster or its driver cannot be resolved
        """
        logging.debug('updating %s', application_session)
        application_session_id = application_session['id']
        cluster_name = application_session['provisioning_config']['cluster']
        if cluster_name is None:
            # only logged here, get_driver() below raises for a cluster name it cannot match
            logging.warning(
                'Cluster/driver config for the application session %s is not found',
                application_session.get('name')
            )

        driver_application_session = self.get_driver(cluster_name)
        driver_application_session.test_connection()
        driver_application_session.update(self.client.token, application_session_id)

    def process_application_session(self, application_session):
        """
        Flag a running session for deletion if its lifetime has run out, then update it
        through the driver.

        :param application_session: freshly fetched session dict
        """
        # check if we need to deprovision the application session
        if application_session.get('state') == ApplicationSession.STATE_RUNNING:
            if not application_session.get('lifetime_left') and application_session.get('maximum_lifetime'):
                logging.info(
                    'deprovisioning triggered for %s (reason: maximum lifetime exceeded)',
                    application_session.get('id')
                )
                self.client.do_application_session_patch(
                    application_session['id'], json_data={'to_be_deleted': True})

        self.update_application_session(application_session)

    @staticmethod
    def _select_sessions_needing_action(sessions):
        """
        Pick the sessions that need processing, each one only once.

        The categories are evaluated in the order queueing, starting, expired and log fetch
        pending. A session matching several categories keeps the position of its first match.
        """
        category_predicates = (
            # waiting to be provisioned
            lambda s: s['state'] == ApplicationSession.STATE_QUEUEING,
            # starting asynchronously
            lambda s: s['state'] == ApplicationSession.STATE_STARTING,
            # expired sessions in need of deprovisioning
            lambda s: s['to_be_deleted'] or (s['lifetime_left'] == 0 and s['maximum_lifetime']),
            # log fetching needed
            lambda s: s['state'] == ApplicationSession.STATE_RUNNING and s['log_fetch_pending'],
        )
        selected = []
        seen_ids = set()
        for predicate in category_predicates:
            for session in sessions:
                if session['id'] in seen_ids:
                    continue
                if predicate(session):
                    seen_ids.add(session['id'])
                    selected.append(session)
        return selected

    def process(self):
        """
        Run one processing round, unless the next scheduled check time has not been reached yet.

        Sessions needing action are locked one at a time, re-fetched and processed. Errors
        of a single session are logged and do not stop the round.
        """
        # process sessions in increased intervals
        if time.time() < self.next_check_ts:
            return
        self.update_next_check_ts(self.polling_interval_min, self.polling_interval_max)

        # Query all non-deleted application sessions. This will be a list of candidates, because other
        # workers could fetch the overlapping sessions as well.
        sessions = self.client.get_application_sessions(limit=SESSION_CONTROLLER_LIMIT_SIZE)
        logging.debug('got %d sessions', len(sessions))

        # extract sessions that need to be processed
        processed_sessions = self._select_sessions_needing_action(sessions)
        if not processed_sessions:
            return

        locks = self.client.query_locks()

        # delete leftover locks that we own
        for lock in locks:
            if lock['owner'] == self.worker_id:
                self.client.release_lock(lock['id'], self.worker_id)

        # Locks held by other workers mark sessions that are already being processed. Our own
        # leftovers were just released above, so they must not cause a skip.
        locked_session_ids = {lock['id'] for lock in locks if lock['owner'] != self.worker_id}

        for session in processed_sessions:
            # skip the ones that are already in progress
            if session['id'] in locked_session_ids:
                logging.debug('skipping locked session %s', session['id'])
                continue

            # try to obtain a lock. Should we lose the race, the winner takes it and we move on
            lock_id = self.client.obtain_lock(session.get('id'), self.worker_id)
            if not lock_id:
                logging.debug('failed to acquire lock on session %s, skipping', session['id'])
                continue

            # process session and release the lock
            try:
                # Now we have the lock, and we can fetch the definite state for the session
                # If the session has been already deleted by another worker, we'll get None
                fresh_session = self.client.get_application_session(session.get('id'), suppress_404=True)
                if fresh_session and fresh_session.get('state') == session.get('state'):
                    self.process_application_session(fresh_session)
                else:
                    logging.info('session %s already processed by another worker', session.get('name'))
            except Exception as e:
                # one failing session must not prevent the rest from being processed
                logging.warning(e)
                logging.debug(traceback.format_exc().splitlines()[-5:])
            finally:
                self.client.release_lock(lock_id, self.worker_id)


class ClusterController(ControllerBase):
    """
    Controller that takes care of cluster resources
    The only task at the moment is to fetch and publish alerts.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.polling_interval_min, self.polling_interval_max = self.get_polling_interval(30, 90)

    @staticmethod
    def _fetch_alerts(cluster):
        """
        Fetch the alert list of a cluster from its Prometheus alerts endpoint.

        :param cluster: cluster configuration dict with 'name' and 'appDomain' keys
        :return: list of alerts, or None if the request failed or the response could not be parsed
        """
        cluster_name = cluster['name']
        try:
            logging.debug('getting alerts for cluster %s', cluster_name)
            res = requests.get(
                url="https://" + cluster['appDomain'] + "/prometheus/api/v1/alerts",
                auth=('token', cluster.get('monitoringToken')),
                timeout=5
            )
        except requests.exceptions.RequestException:
            res = None

        if not (res and res.ok):
            logging.warning('unable to get alerts from cluster %s', cluster_name)
            return None

        try:
            return res.json()['data']['alerts']
        except (ValueError, KeyError, TypeError):
            # body is not JSON or does not have the expected structure
            logging.warning('unable to parse alerts from cluster %s', cluster_name)
            return None

    def process(self):
        """
        Poll the alerts of every configured cluster and publish the result through the API client.

        Clusters without 'appDomain' or with 'disableAlerts' set are skipped. A failure with
        one cluster is logged and does not stop the processing of the others.
        """
        # process clusters in increased intervals
        if time.time() < self.next_check_ts:
            return
        self.update_next_check_ts(self.polling_interval_min, self.polling_interval_max)

        logging.debug('checking cluster alerts')

        # comma separated list of alert names that are never published
        alertnames_to_ignore = None
        if 'ALERTNAMES_TO_IGNORE' in os.environ:
            alertnames_to_ignore = os.environ.get('ALERTNAMES_TO_IGNORE').split(',')

        for cluster in self.cluster_config['clusters']:
            cluster_name = cluster['name']

            if 'appDomain' not in cluster:
                continue

            if cluster.get('disableAlerts', False):
                logging.debug('alerts disabled for cluster %s', cluster_name)
                continue

            alerts = self._fetch_alerts(cluster)
            if alerts is None:
                continue

            logging.debug('got %d alert entries for cluster %s', len(alerts), cluster_name)

            # the watchdog alert should be always firing
            if len(alerts) == 0:
                logging.warning('zero alerts, watchdog is not working for cluster %s', cluster_name)
                continue

            # filter out low severity ('none', 'info') and speculative alerts (state not 'firing')
            real_alerts = [
                alert for alert in alerts
                if alert['labels'].get('severity', 'none') not in ('none', 'info') and alert['state'] == 'firing'
            ]

            if alertnames_to_ignore is not None:
                real_alerts = [
                    alert for alert in real_alerts
                    if alert['labels'].get('alertname') not in alertnames_to_ignore
                ]

            if real_alerts:
                logging.info('found %d alerts for cluster %s', len(real_alerts), cluster_name)

                # add real alerts to post data
                json_data = [
                    dict(
                        target=cluster_name,
                        source='prometheus',
                        status='firing',
                        data=alert
                    )
                    for alert in real_alerts
                ]

                # add notification that the cluster has been polled successfully
                json_data.append(
                    dict(
                        target=cluster_name,
                        source='prometheus',
                        status='ok',
                        data=dict()
                    )
                )
                res = self.client.do_post(
                    object_url='alerts',
                    json_data=json_data
                )
            else:
                # inform API that cluster is ok and archive any firing alerts
                res = self.client.do_post(
                    object_url='alert_reset/%s/%s' % (cluster_name, 'prometheus'),
                    json_data=None)

            if not res.ok:
                logging.warning('unable to update alerts in api, code/reason: %s/%s', res.status_code, res.reason)


class WorkspaceController(ControllerBase):
    """
    Controller that takes care of Workspace tasks
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.polling_interval_min, self.polling_interval_max = self.get_polling_interval(30, 90)
        # upper limit for the number of unfinished tasks handled in one round
        self.max_concurrent_tasks = 20

    def process(self):
        """
        Run one round of workspace task processing under the global task lock.

        A task whose processing raises is marked as failed, with the first lines of the error
        message stored as its results. Tasks of unknown kind are only logged.
        """
        # process workspace management in increased intervals
        if time.time() < self.next_check_ts:
            return
        self.update_next_check_ts(self.polling_interval_min, self.polling_interval_max)

        # Try to obtain a global lock for WorkspaceTaskProcessing.
        # Should we lose the race, the winner takes it, and we try next time we are active
        lock = self.client.obtain_lock(WS_CONTROLLER_TASK_LOCK_NAME, self.worker_id)
        if not lock:
            logging.debug('WorkspaceController did not acquire lock, skipping')
            return

        try:
            logging.debug('WorkspaceController: checking tasks')
            unfinished_tasks = self.client.get_tasks(unfinished=1)

            # order by creation timestamp and id, latest first
            unfinished_tasks = sorted(
                unfinished_tasks,
                key=lambda t: '%d-%s' % (t.get('create_ts'), t.get('id')),
                reverse=True
            )
            tasks = unfinished_tasks[:self.max_concurrent_tasks]

            for task in tasks:
                logging.debug(task)
                # process tasks and release the lock
                try:
                    task_kind = task.get('kind')
                    if task_kind == Task.KIND_WORKSPACE_VOLUME_BACKUP:
                        self.process_volume_backup_task(task)
                    elif task_kind == Task.KIND_WORKSPACE_VOLUME_RESTORE:
                        self.process_volume_restore_task(task)
                    else:
                        # tasks are dicts, so the kind has to be read with get()
                        logging.warning('unknown task kind: %s', task_kind)
                except Exception as e:
                    logging.warning('Marking task %s FAILED due to "%s"', task.get('id'), e)
                    self.client.update_task(task.get('id'), state=Task.STATE_FAILED)
                    self.client.add_task_results(
                        task.get('id'),
                        results='\n'.join(str(e).splitlines()[:4])
                    )
        finally:
            self.client.release_lock(WS_CONTROLLER_TASK_LOCK_NAME, self.worker_id)

    @staticmethod
    def get_volume_name(task_data):
        """
        Map task data to the name of the volume the task operates on.

        :param task_data: dict with 'type' and, for the 'user-data' type, 'pseudonym'
        :return: 'pvc-ws-vol-1' for 'shared-data', 'pvc-<pseudonym>-work' for 'user-data'
        :raises RuntimeWarning: for any other type
        """
        volume_type = task_data.get('type')
        if volume_type == 'shared-data':
            return 'pvc-ws-vol-1'
        if volume_type == 'user-data':
            return 'pvc-%s-work' % task_data.get('pseudonym')
        raise RuntimeWarning('Unknown task type "%s" encountered' % volume_type)

    def process_volume_backup_task(self, task):
        """
        Advance a volume backup task by one step.

        A new task is marked as processing and a backup job is created. For a processing
        task the job is checked and the task is marked finished when the check succeeds.

        :param task: task dict with 'id', 'state' and 'data' ('cluster', 'workspace_id', 'type')
        :raises RuntimeError: if there is no driver for the cluster
        """
        task_id = task.get('id')
        task_data = task.get('data')
        driver = self.get_driver(task_data.get('cluster'))
        if not driver:
            raise RuntimeError(
                'No driver for cluster %s in task %s' % (task_data.get('cluster'), task_id))

        task_state = task.get('state')
        if task_state == Task.STATE_NEW:
            logging.info('Starting processing of task %s', task_id)
            self.client.update_task(task_id, state=Task.STATE_PROCESSING)

            driver.create_volume_backup_job(
                self.client.token,
                task_data.get('workspace_id'),
                self.get_volume_name(task_data),
            )
        elif task_state == Task.STATE_PROCESSING:
            if driver.check_volume_backup_job(
                    self.client.token,
                    task_data.get('workspace_id'),
                    self.get_volume_name(task_data),
            ):
                logging.info('Task %s FINISHED', task_id)
                self.client.update_task(task_id, state=Task.STATE_FINISHED)
        else:
            logging.warning(
                'task %s in state %s should not end up in processing', task_id, task_state)

    def process_volume_restore_task(self, task):
        """
        Advance a volume restore task by one step.

        A new task is marked as processing and a restore job is created with a volume size
        and storage class that depend on the data type. For a processing task the job is
        checked and the task is marked finished when the check succeeds.

        :param task: task dict with 'id', 'state' and 'data' ('tgt_cluster', 'src_cluster',
                     'workspace_id', 'type')
        :raises RuntimeError: if the driver, workspace id or source cluster is missing
        :raises RuntimeWarning: if the data type of a new task is unknown
        """
        task_id = task.get('id')
        task_data = task.get('data')
        driver = self.get_driver(task_data.get('tgt_cluster'))
        if not driver:
            raise RuntimeError(
                'No driver for tgt_cluster %s in task %s' % (task_data.get('tgt_cluster'), task_id))
        ws_id = task_data.get('workspace_id')
        if not ws_id:
            raise RuntimeError('No data.workspace_id in task %s' % task_id)
        src_cluster = task_data.get('src_cluster')
        if not src_cluster:
            raise RuntimeError('No data.src_cluster in task %s' % task_id)

        task_state = task.get('state')
        if task_state == Task.STATE_NEW:
            logging.info('Starting processing of task %s', task_id)
            ws = self.client.get_workspace(ws_id)
            self.client.update_task(task_id, state=Task.STATE_PROCESSING)

            # figure out right size for user work volume
            if task_data.get('type') == 'shared-data':
                volume_size_gib = ws.get('config', {}).get('shared_folder_size_gib', 20)
                storage_class_name = driver.cluster_config.get('storageClassNameShared')
            elif task_data.get('type') == 'user-data':
                volume_size_gib = ws.get('config', {}).get('user_work_folder_size_gib', 1)
                storage_class_name = driver.cluster_config.get('storageClassNameUser')
            else:
                raise RuntimeWarning('Unknown task type "%s" encountered' % task_data.get('type'))

            driver.create_volume_restore_job(
                token=self.client.token,
                workspace_id=ws_id,
                volume_name=self.get_volume_name(task_data),
                volume_size_spec='%dGi' % volume_size_gib,
                storage_class=storage_class_name,
                src_cluster=src_cluster,
            )
        elif task_state == Task.STATE_PROCESSING:
            if driver.check_volume_restore_job(self.client.token, ws_id, self.get_volume_name(task_data)):
                logging.info('Task %s FINISHED', task_id)
                self.client.update_task(task_id, state=Task.STATE_FINISHED)
        else:
            logging.warning(
                'task %s in state %s should not end up in processing', task_id, task_state)