atlassian/themis

themis/util/aws_common.py

import re
import json
import logging
import boto3
import os
import pytz
from datetime import datetime
from themis import config, constants
from themis.config import SECTION_EMR
from themis.util.common import run_func, remove_lines_from_string, get_logger, short_uid, is_ip_address
from themis.util.common import CURL_CONNECT_TIMEOUT, STATIC_INFO_CACHE_TIMEOUT, QUERY_CACHE_TIMEOUT
from themis.util.remote import run_ssh

# constants

PRESTO_STATE_SHUTTING_DOWN = 'SHUTTING_DOWN'
PRESTO_STATE_ACTIVE = 'ACTIVE'

INSTANCE_GROUP_TYPE_MASTER = 'MASTER'
INSTANCE_GROUP_TYPE_CORE = 'CORE'
INSTANCE_GROUP_TYPE_TASK = 'TASK'

INSTANCE_STATE_RUNNING = 'RUNNING'
INSTANCE_STATE_TERMINATED = 'TERMINATED'

INVALID_CONFIG_VALUE = '_invalid_value_to_avoid_restart_'

CW_TIMESTAMP_FORMAT = '%Y-%m-%dT%H:%M:%SZ'

# TODO move to model
CLUSTER_TYPE_PRESTO = 'Presto'
CLUSTER_TYPE_HIVE = 'Hive'

# Define local test endpoints
TEST_ENDPOINTS = {}

# Keep a reference to the initial boto3 session with default credentials
INITIAL_BOTO3_SESSION = None

# Map role ARNs to sessions with assumed roles
ASSUMED_ROLE_SESSIONS = {}

# Time window (in seconds) before credential expiry within which a session is renewed
SESSION_EXPIRY_RENEW_PERIOD = 10 * 60

# logger
LOG = get_logger(__name__)


class StsSession(object):
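    """Lazily created boto3 session for an assumed IAM role.

    If no role ARN is given, get() falls back to the default boto3 module;
    otherwise an assumed-role session is created on first use and renewed
    shortly before its temporary credentials expire.
    """
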
    def __init__(self, role_arn):
        self.session = None
        self.role_arn = role_arn

    def get(self):
        if not self.role_arn:
            # work with default boto3 session
            return boto3
        if not self.session or self.expires_soon():
            # create a new boto3 session using assume-role
            self.session = do_assume_role(self.role_arn)
        return self.session

    def expires_soon(self):
        return session_expires_soon(self.session)


def session_expires_soon(session):
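    """Return True if the session's credentials expire within SESSION_EXPIRY_RENEW_PERIOD seconds."""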
    exp = session.expiration
    now = datetime.now()
    # strip timezone info so the two datetimes can be subtracted
    exp = exp.replace(tzinfo=None)
    now = now.replace(tzinfo=None)
    # get delta
    delta = (exp - now).total_seconds()
    return delta < SESSION_EXPIRY_RENEW_PERIOD


def init_aws_cli():
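    """Configure local test endpoints from the AWS_ENDPOINT_URL environment variable, if set."""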
    endpoint_url = os.environ.get('AWS_ENDPOINT_URL')
    if endpoint_url:
        TEST_ENDPOINTS['emr'] = endpoint_url
        TEST_ENDPOINTS['kinesis'] = endpoint_url
        TEST_ENDPOINTS['cloudwatch'] = endpoint_url
        TEST_ENDPOINTS['ec2'] = endpoint_url


def connect_emr(session=None, role=None):
    return connect_to_service('emr', session=session, role=role)


def connect_kinesis(session=None, role=None):
    return connect_to_service('kinesis', session=session, role=role)


def connect_cloudwatch(session=None, role=None):
    return connect_to_service('cloudwatch', session=session, role=role)


def connect_ec2(session=None, role=None):
    return connect_to_service('ec2', session=session, role=role)


def connect_to_service(service, session=None, role=None):
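    """Return a boto3 client for the given AWS service.

    Uses the provided session, or assumes the given role, or falls back to the
    default boto3 credentials. Local test endpoints are honored if configured.
    """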
    session = session if session else assume_role(role) if role else None
    endpoint_url = TEST_ENDPOINTS.get(service)
    boto3_session = session.get() if session else boto3
    return boto3_session.client(service, endpoint_url=endpoint_url)


def assume_role(role_arn):
    return StsSession(role_arn)


def do_assume_role(role_arn):
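    """Return a boto3 session with temporary credentials for the given role ARN.

    Sessions are cached per role ARN and re-created shortly before their
    credentials expire.
    """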
    session = ASSUMED_ROLE_SESSIONS.get(role_arn)
    if not session or session_expires_soon(session):
        global INITIAL_BOTO3_SESSION
        # save initial session with micros credentials to use for assuming role
        if not INITIAL_BOTO3_SESSION:
            INITIAL_BOTO3_SESSION = boto3.session.Session()
        client = INITIAL_BOTO3_SESSION.client('sts')
        # generate a random role session name
        role_session_name = 's-' + str(short_uid())
        # make API call to assume role
        response = client.assume_role(RoleArn=role_arn, RoleSessionName=role_session_name)
        # save session with new credentials to use for all aws calls
        session = boto3.session.Session(
            aws_access_key_id=response['Credentials']['AccessKeyId'],
            aws_secret_access_key=response['Credentials']['SecretAccessKey'],
            aws_session_token=response['Credentials']['SessionToken'])
        session.expiration = response['Credentials']['Expiration']
        # store session for future reference
        ASSUMED_ROLE_SESSIONS[role_arn] = session
    return session


def ip_to_hostname(ip, domain_name):
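    """Convert a private IP to an EC2-style internal hostname, e.g.
    ip_to_hostname('10.0.1.5', 'ec2.internal') -> 'ip-10-0-1-5.ec2.internal'."""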
    return 'ip-' + re.sub(r'\.', r'-', ip) + '.' + domain_name


def hostname_to_ip(host):
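    """Extract the private IP from an EC2-style internal hostname, e.g.
    hostname_to_ip('ip-10-0-1-5.ec2.internal') -> '10.0.1.5'."""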
    return re.sub(r'ip-([0-9]+)-([0-9]+)-([0-9]+)-([0-9]+)\..+', r'\1.\2.\3.\4', host)


def parse_cloudwatch_timestamp(timestamp, tzinfo=pytz.UTC):
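    """Parse a CloudWatch timestamp string such as '2016-01-01T12:00:00Z' into a timezone-aware datetime."""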
    time = datetime.strptime(timestamp, CW_TIMESTAMP_FORMAT)
    time = time.replace(tzinfo=tzinfo)
    return time


def format_cloudwatch_timestamp(timestamp):
    return timestamp.strftime(CW_TIMESTAMP_FORMAT)


def get_instance_by_ip(ip, role=None):
    ec2_client = connect_ec2(role=role)
    result = run_func(ec2_client.describe_instances,
                      Filters=[{'Name': 'private-ip-address', 'Values': [ip]}],
                      cache_duration_secs=STATIC_INFO_CACHE_TIMEOUT)
    result = json.loads(result)
    return result['Reservations'][0]['Instances'][0]


def get_instance_groups(cluster_id, role=None):
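    """Return the cluster's instance groups as a map from group type
    (MASTER/CORE/TASK) to a list of group dicts, each annotated with
    lowercase 'id' and 'type' keys."""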
    emr_client = connect_emr(role=role)
    result = run_func(emr_client.list_instance_groups, ClusterId=cluster_id,
                      cache_duration_secs=QUERY_CACHE_TIMEOUT)
    result = json.loads(result)
    result_map = {}
    for group in result['InstanceGroups']:
        group_type = group['InstanceGroupType']
        if group_type not in result_map:
            result_map[group_type] = []
        group['id'] = group['Id']
        group['type'] = group_type
        result_map[group_type].append(group)

    return result_map


def get_instance_groups_tasknodes(cluster_id, role=None):
    return get_instance_groups(cluster_id, role=role)[INSTANCE_GROUP_TYPE_TASK]


def get_instance_groups_nodes(cluster_id, role=None, instance_group_type=INSTANCE_GROUP_TYPE_CORE):
    return get_instance_groups(cluster_id, role=role)[instance_group_type]


def get_instance_groups_ids(cluster_id, group_type, role=None):
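    """Return the instance group dicts of the given type (despite the name, full dicts rather than only ids)."""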
    groups = get_instance_groups(cluster_id, role=role)
    return groups[group_type]


def get_instance_group_type(cluster_id, group_id, role=None):
    return get_instance_group_details(cluster_id, group_id, role=role)['type']


def get_instance_group_details(cluster_id, group_id, role=None):
    groups = get_instance_groups(cluster_id, role=role)
    for key, arr in groups.items():
        for val in arr:
            if val['id'] == group_id:
                return val
    return None


def get_instance_group_for_node(cluster_id, node_host, role=None):
    nodes = get_cluster_nodes(cluster_id, role=role)
    for node in nodes:
        if node['host'] == node_host or node['ip'] == node_host:
            return node['gid']
    return None


def get_cluster_nodes(cluster_id, role=None):
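    """List the cluster's non-terminated instances.

    Each instance dict is annotated with shorthand keys: 'cid' (EMR instance id),
    'iid' (EC2 instance id), 'gid' (instance group id), 'ip', 'host', 'type',
    'state' and 'market'. If a custom domain name is configured for the cluster,
    'host' is rewritten to use it.
    """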
    emr_client = connect_emr(role=role)
    result = run_func(emr_client.list_instances, ClusterId=cluster_id,
                      InstanceStates=['AWAITING_FULFILLMENT', 'PROVISIONING', 'BOOTSTRAPPING', 'RUNNING'],
                      cache_duration_secs=QUERY_CACHE_TIMEOUT)
    result = json.loads(result)
    result = result['Instances']

    # read domain name config
    custom_dn = config.get_value(constants.KEY_CUSTOM_DOMAIN_NAME, section=SECTION_EMR, resource=cluster_id)

    i = 0
    while i < len(result):
        inst = result[i]
        if inst['Status']['State'] == INSTANCE_STATE_TERMINATED:
            # drop terminated instances; decrement i to compensate for the increment below
            del result[i]
            i -= 1
        else:
            inst['cid'] = inst.get('Id', 'n/a')
            inst['iid'] = inst.get('Ec2InstanceId', 'n/a')
            inst['gid'] = inst.get('InstanceGroupId', 'n/a')
            inst['ip'] = inst.get('PrivateIpAddress', 'n/a')
            inst['host'] = inst.get('PrivateDnsName', 'n/a')
            if custom_dn:
                inst['host'] = ip_to_hostname(hostname_to_ip(inst['host']), custom_dn)
            inst['type'] = get_instance_group_type(cluster_id, inst['InstanceGroupId'], role=role)
            inst['state'] = inst['Status']['State']
            inst['market'] = get_instance_group_details(cluster_id, inst['InstanceGroupId'], role=role)['Market']
        i += 1
    return result


def get_all_task_nodes(cluster_id, cluster_ip, role=None):
    all_nodes = get_cluster_nodes(cluster_id, role=role)
    task_nodes = [n for n in all_nodes if n['type'] == INSTANCE_GROUP_TYPE_TASK]
    return task_nodes


def terminate_task_node(instance_group_id, instance_id, role=None):
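    """Terminate a single EC2 instance of the given instance group via modify_instance_groups."""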
    # terminate instance
    emr_client = connect_emr(role=role)
    LOG.info('Terminate instance %s of instance group %s' % (instance_id, instance_group_id))
    result = emr_client.modify_instance_groups(InstanceGroups=[
        {'InstanceGroupId': instance_group_id, 'EC2InstanceIdsToTerminate': [instance_id]}
    ])
    return result


def spawn_task_node(instance_group_id, current_size, additional_nodes=1, role=None):
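    """Grow the instance group by setting its target InstanceCount to current_size + additional_nodes."""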
    # start new instance
    emr_client = connect_emr(role=role)
    new_size = current_size + additional_nodes
    LOG.info('Increase instances of instance group %s from %s to %s' %
             (instance_group_id, current_size, new_size))
    result = emr_client.modify_instance_groups(InstanceGroups=[
        {'InstanceGroupId': instance_group_id, 'InstanceCount': new_size}
    ])
    return result


def is_presto_cluster(cluster):
    return cluster.type == CLUSTER_TYPE_PRESTO


def terminate_inactive_nodes(cluster, cluster_state, role=None):
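    """Terminate running nodes of a Presto cluster that are neither ACTIVE nor SHUTTING_DOWN.

    A node is only terminated if its config.properties contains INVALID_CONFIG_VALUE,
    i.e. it has previously been marked for decommissioning via set_presto_node_state().
    """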
    if not is_presto_cluster(cluster):
        return
    nodes = cluster_state['nodes']
    for key, node in nodes.items():
        if (node['state'] == INSTANCE_STATE_RUNNING and
                node['presto_state'] not in [PRESTO_STATE_SHUTTING_DOWN, PRESTO_STATE_ACTIVE]):
            try:
                cmd = "cat /etc/presto/conf/config.properties | grep http-server.threads"
                out = run_ssh(cmd, cluster.ip, user='hadoop',
                              via_hosts=[node['host']], cache_duration_secs=QUERY_CACHE_TIMEOUT)
                if INVALID_CONFIG_VALUE in out:
                    LOG.info("Terminating instance of idle node %s in instance group %s" %
                             (node['iid'], node['gid']))
                    terminate_task_node(instance_group_id=node['gid'], instance_id=node['iid'], role=role)
            except Exception as e:
                LOG.info("Unable to read Presto config from node %s: %s" % (node, e))


def get_presto_node_state(cluster_ip, node_ip):
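    """Query the node's Presto REST API (http://<node>:8889/v1/info/state) via SSH
    through the cluster master and return the state string, e.g. 'ACTIVE'."""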
    cmd = ('curl -s --connect-timeout %s --max-time %s http://%s:8889/v1/info/state' %
           (CURL_CONNECT_TIMEOUT, CURL_CONNECT_TIMEOUT, node_ip))
    out = run_ssh(cmd, cluster_ip, user='hadoop', cache_duration_secs=QUERY_CACHE_TIMEOUT)
    out = remove_lines_from_string(out, r'.*Permanently added.*')
    out = re.sub(r'\s*"(.+)"\s*', r'\1', out)
    return out


def set_presto_node_state(cluster_ip, node_ip, state):
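    """Set the node's Presto state via the REST API (PUT /v1/info/state).

    Also overwrites http-server.threads.max in the node's config.properties with
    INVALID_CONFIG_VALUE, apparently so the node cannot rejoin cleanly after a
    restart (see terminate_inactive_nodes).
    """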
    cmd = ("sudo sed -i 's/http-server.threads.max=.*/http-server.threads.max=%s/g' " +
           "/etc/presto/conf/config.properties") % INVALID_CONFIG_VALUE
    if not is_ip_address(cluster_ip):
        cluster_ip = hostname_to_ip(cluster_ip)
    out = run_ssh(cmd, cluster_ip, user='hadoop', via_hosts=[node_ip], cache_duration_secs=QUERY_CACHE_TIMEOUT)

    cmd = ("curl -s --connect-timeout %s -X PUT -H 'Content-Type:application/json' " +
           "-d '\\\"%s\\\"' http://%s:8889/v1/info/state") % (CURL_CONNECT_TIMEOUT, state, node_ip)
    LOG.info(cmd)
    out = run_ssh(cmd, cluster_ip, user='hadoop', cache_duration_secs=QUERY_CACHE_TIMEOUT)
    out = re.sub(r'\s*"(.+)"\s*', r'\1', out)
    return out