atlassian/localstack

View on GitHub
localstack/utils/aws/aws_stack.py

Summary

Maintainability
D
1 day
Test Coverage
import os
import boto3
import json
import base64
import logging
import re
from six import iteritems
from threading import Timer
from localstack import config
from localstack.constants import *
from localstack.utils.common import *
from localstack.utils.aws.aws_models import *

# AWS environment variable names
# (read by connect_elasticsearch() when building request-signing credentials)
ENV_ACCESS_KEY = 'AWS_ACCESS_KEY_ID'
ENV_SECRET_KEY = 'AWS_SECRET_ACCESS_KEY'
ENV_SESSION_TOKEN = 'AWS_SESSION_TOKEN'

# set up logger
LOGGER = logging.getLogger(__name__)

# Use this field if you want to provide a custom boto3 session.
# This field takes priority over CREATE_NEW_SESSION_PER_BOTO3_CONNECTION
# (see get_boto3_session()); it is also the credential source used by
# get_boto3_credentials() and connect_elasticsearch() when set.
CUSTOM_BOTO3_SESSION = None
# Use this flag to enable creation of a new session for each boto3 connection.
# This flag will be ignored if CUSTOM_BOTO3_SESSION is specified.
# When False (and no custom session is set), get_boto3_session() returns the
# boto3 module itself, i.e. boto3's default session is used.
CREATE_NEW_SESSION_PER_BOTO3_CONNECTION = False

# Used in AWS assume role function
# NOTE(review): not referenced in this part of the file - presumably set by
# assume-role code elsewhere; confirm.
INITIAL_BOTO3_SESSION = None

# Assume role loop seconds (60 * 50 = 3000 seconds, i.e. 50 minutes)
# NOTE(review): presumably the re-run interval of the assume-role loop; its
# user is not visible here - confirm.
DEFAULT_TIMER_LOOP_SECONDS = 60 * 50


class Environment(object):
    """A deployment target: an AWS region plus an optional environment prefix.

    The string form is '<region>:<prefix>' (see __str__ and from_string).
    """

    def __init__(self, region=None, prefix=None):
        # region is the AWS region to use, or e.g. 'local' (REGION_LOCAL) for
        # local mode; falls back to DEFAULT_REGION when not given
        self.region = region or DEFAULT_REGION
        # prefix can be 'prod', 'stg', 'uat-1', etc.
        self.prefix = prefix

    def apply_json(self, j):
        """Update this object's attributes from a dict or a JSON object string."""
        if isinstance(j, str):
            j = json.loads(j)
        self.__dict__.update(j)

    @staticmethod
    def from_string(s):
        """Parse an environment string into an Environment.

        Accepted forms:
            * '<region>:<prefix>'
            * a predefined environment name (a key of PREDEFINED_ENVIRONMENTS)
            * '<prefix>' alone, which implies DEFAULT_REGION

        Predefined environments are returned as a shallow *copy*, so callers
        that modify the result (e.g. overriding its region) cannot corrupt the
        shared PREDEFINED_ENVIRONMENTS entry for every later lookup.

        Raises:
            Exception: if the string contains more than one ':'.
        """
        import copy
        parts = s.split(':')
        if len(parts) == 1:
            if s in PREDEFINED_ENVIRONMENTS:
                return copy.copy(PREDEFINED_ENVIRONMENTS[s])
            parts = [DEFAULT_REGION, s]
        if len(parts) > 2:
            raise Exception('Invalid environment string "%s"' % s)
        region = parts[0]
        prefix = parts[1]
        return Environment(region=region, prefix=prefix)

    @staticmethod
    def from_json(j):
        """Build an Environment from a dict, or from an object exposing to_dict()."""
        if not isinstance(j, dict):
            j = j.to_dict()
        result = Environment()
        result.apply_json(j)
        return result

    def __str__(self):
        return '%s:%s' % (self.region, self.prefix)


# Named environments that Environment.from_string() recognizes by name alone.
# ENV_DEV maps to local mode (region REGION_LOCAL) with prefix ENV_DEV.
# NOTE(review): entries are shared module-level instances - mutating one in
# place affects every later lookup of that name.
PREDEFINED_ENVIRONMENTS = {
    ENV_DEV: Environment(region=REGION_LOCAL, prefix=ENV_DEV)
}


def get_environment(env=None, region_name=None):
    """
    Return an Environment object based on the input arguments.

    Parameter `env` can be either of:
        * None (or empty), in which case the rules below are applied to (env = os.environ['ENV'] or ENV_DEV)
        * an Environment object (returned as-is, unless `region_name` is given)
        * a string '<region>:<name>', which corresponds to Environment(region='<region>', prefix='<name>')
        * the predefined string 'dev' (ENV_DEV), which implies Environment(region='local', prefix='dev')
        * a string '<name>', which implies Environment(region=DEFAULT_REGION, prefix='<name>')

    Additionally, parameter `region_name` can be used to override the region of the
    resulting environment. The override is applied to a copy, so neither a
    caller-supplied Environment nor a shared predefined environment is modified.

    Raises:
        Exception: if `env` is neither a string nor an Environment, or if the
            resulting environment has an empty region.
    """
    import copy
    if not env:
        env = os.environ.get('ENV', ENV_DEV)
    elif not is_string(env) and not isinstance(env, Environment):
        raise Exception('Invalid environment: %s' % env)

    if is_string(env):
        env = Environment.from_string(env)
    if region_name:
        # Copy before overriding: `env` may be the caller's object or a shared
        # PREDEFINED_ENVIRONMENTS entry, and mutating either would leak the
        # override into unrelated later calls.
        env = copy.copy(env)
        env.region = region_name
    if not env.region:
        raise Exception('Invalid region in environment: "%s"' % env)
    return env


def connect_to_resource(service_name, env=None, region_name=None, endpoint_url=None):
    """
    Return a boto3 *resource* (rather than a low-level client) for `service_name`.

    Thin wrapper around connect_to_service(..., client=False); `env`, `region_name`
    and `endpoint_url` are interpreted exactly as there.
    """
    connection_args = {'env': env, 'region_name': region_name, 'endpoint_url': endpoint_url}
    return connect_to_service(service_name, client=False, **connection_args)


def get_boto3_credentials():
    """Return the credentials of CUSTOM_BOTO3_SESSION if set, else of a new default boto3 Session."""
    session = CUSTOM_BOTO3_SESSION or boto3.session.Session()
    return session.get_credentials()


def get_boto3_session():
    """Return the boto3 session (or session-like object) used to create clients/resources.

    Priority:
        1. CUSTOM_BOTO3_SESSION, if set;
        2. a brand-new boto3.session.Session(), if CREATE_NEW_SESSION_PER_BOTO3_CONNECTION is set;
        3. otherwise the ``boto3`` module itself, whose ``client``/``resource``
           functions use boto3's default session.
    """
    if CUSTOM_BOTO3_SESSION:
        return CUSTOM_BOTO3_SESSION
    if CREATE_NEW_SESSION_PER_BOTO3_CONNECTION:
        return boto3.session.Session()
    # return default session
    return boto3


def get_local_service_url(service_name):
    """Return the local endpoint URL of `service_name` from the TEST_<NAME>_URL environment variable.

    '-' in the name becomes '_' and 's3api' is treated as an alias of 's3'.
    Raises KeyError if the variable is not set.
    """
    canonical_name = 's3' if service_name == 's3api' else service_name
    env_var_name = 'TEST_%s_URL' % canonical_name.upper().replace('-', '_')
    return os.environ[env_var_name]


def connect_to_service(service_name, client=True, env=None, region_name=None, endpoint_url=None):
    """
    Return a boto3 client (default) or resource (`client=False`) for `service_name`.

    The environment is resolved via get_environment(env, region_name). In local mode
    (region REGION_LOCAL) and without an explicit `endpoint_url`, the endpoint comes from
    get_local_service_url() and TLS verification is disabled; DEFAULT_REGION is then passed
    to boto3 as the region.
    """
    env = get_environment(env, region_name=region_name)
    session = get_boto3_session()
    factory = session.client if client else session.resource
    is_local = env.region == REGION_LOCAL
    verify = True
    if is_local and not endpoint_url:
        endpoint_url = get_local_service_url(service_name)
        verify = False
    boto_region = DEFAULT_REGION if is_local else env.region
    return factory(service_name, region_name=boto_region, endpoint_url=endpoint_url, verify=verify)


class VelocityInput:
    """Simple class to mimic the behavior of variable '$input' in AWS API Gateway integration velocity templates.
    See: http://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html"""
    def __init__(self, value):
        # value is either already-parsed JSON (dict or list) or a JSON document string
        self.value = value

    def path(self, path):
        """Evaluate the JSONPath expression `path` against the input.

        Returns the matched value itself when exactly one element matches,
        otherwise the list of all matches (possibly empty).
        """
        from jsonpath_rw import parse
        # Already-parsed JSON (object or array) is used directly; anything else is
        # decoded as JSON text. Previously a list value was passed to json.loads,
        # which raises TypeError.
        if isinstance(self.value, (dict, list)):
            value = self.value
        else:
            value = json.loads(self.value)
        jsonpath_expr = parse(path)
        result = [match.value for match in jsonpath_expr.find(value)]
        return result[0] if len(result) == 1 else result

    def json(self, path):
        """Like path(), but return the result serialized as a JSON string."""
        return json.dumps(self.path(path))


class VelocityUtil:
    """Simple class to mimic the behavior of variable '$util' in AWS API Gateway integration velocity templates.
    See: http://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html"""
    def base64Encode(self, s):
        """Base64-encode `s` (JSON-serialized first unless it is a str) and return the result as text."""
        text = s if isinstance(s, str) else json.dumps(s)
        b64_bytes = base64.b64encode(text.encode(config.DEFAULT_ENCODING))
        return b64_bytes.decode(config.DEFAULT_ENCODING)

    def base64Decode(self, s):
        """Base64-decode `s` (JSON-serialized first unless it is a str) and return the raw bytes."""
        text = s if isinstance(s, str) else json.dumps(s)
        return base64.b64decode(text)


def render_velocity_template(template, context, as_json=False):
    """Render an API Gateway-style Velocity `template` against `context`.

    The template sees ``$input`` (a VelocityInput over `context`) and ``$util``
    (a VelocityUtil). When `as_json` is true, the rendered text is parsed as JSON.
    """
    import airspeed
    compiled = airspeed.Template(template)
    template_vars = dict(input=VelocityInput(context), util=VelocityUtil())
    rendered = compiled.merge(template_vars)
    return json.loads(rendered) if as_json else rendered


def get_account_id(account_id=None, env=None):
    """Return `account_id` if given; in local mode fall back to $TEST_AWS_ACCOUNT_ID.

    Raises:
        Exception: if no account ID is given and the environment is not local.
    """
    if not account_id:
        resolved_env = get_environment(env)
        if resolved_env.region != REGION_LOCAL:
            raise Exception('Unable to determine AWS account ID')
        account_id = os.environ['TEST_AWS_ACCOUNT_ID']
    return account_id


def role_arn(role_name, account_id=None, env=None):
    """Return the ARN of IAM role `role_name` in the (resolved) account."""
    resolved_env = get_environment(env)
    resolved_account = get_account_id(account_id, env=resolved_env)
    return 'arn:aws:iam::%s:role/%s' % (resolved_account, role_name)


def iam_resource_arn(resource, role=None, env=None):
    """Return the IAM role ARN used for `resource`.

    If `role` is not given, the conventional role name from get_iam_role() is used.
    The resolved environment is passed through to the account ID lookup, so the
    ARN refers to the account of `env` rather than the process-default environment
    (previously `env` was resolved but then ignored for the account lookup).
    """
    env = get_environment(env)
    if not role:
        role = get_iam_role(resource, env=env)
    return role_arn(role_name=role, env=env)


def get_iam_role(resource, env=None):
    """Return the IAM role name used for `resource`, i.e. 'role-<resource>'.

    `env` does not influence the name; it is still resolved through
    get_environment(), which raises on an invalid value.
    """
    get_environment(env)  # validation only; the resolved value is not needed
    return 'role-%s' % resource


def dynamodb_table_arn(table_name, account_id=None):
    """Return the ARN of DynamoDB table `table_name` in DEFAULT_REGION."""
    resolved_account = get_account_id(account_id)
    return 'arn:aws:dynamodb:%s:%s:table/%s' % (
        DEFAULT_REGION, resolved_account, table_name)


def dynamodb_stream_arn(table_name, account_id=None):
    """Return a DynamoDB stream ARN for `table_name`, using timestamp() as the stream label."""
    resolved_account = get_account_id(account_id)
    stream_label = timestamp()
    return 'arn:aws:dynamodb:%s:%s:table/%s/stream/%s' % (
        DEFAULT_REGION, resolved_account, table_name, stream_label)


def lambda_function_arn(function_name, account_id=None):
    """Return the ARN of Lambda function `function_name` in DEFAULT_REGION.

    A value that already looks like a Lambda function ARN is returned unchanged.

    Raises:
        Exception: if `function_name` is not such an ARN but contains a ':'.
    """
    arn_pattern = 'arn:aws:lambda:.*:.*:function:.*'
    if re.match(arn_pattern, function_name):
        return function_name
    if ':' in function_name:
        raise Exception('Lambda function name should not contain a colon ":"')
    resolved_account = get_account_id(account_id)
    return 'arn:aws:lambda:%s:%s:function:%s' % (DEFAULT_REGION, resolved_account, function_name)


def cognito_user_pool_arn(user_pool_id, account_id=None):
    """Return the ARN of Cognito user pool `user_pool_id` in DEFAULT_REGION."""
    resolved_account = get_account_id(account_id)
    return 'arn:aws:cognito-idp:%s:%s:userpool/%s' % (
        DEFAULT_REGION, resolved_account, user_pool_id)


def kinesis_stream_arn(stream_name, account_id=None):
    """Return the ARN of Kinesis stream `stream_name` in DEFAULT_REGION."""
    resolved_account = get_account_id(account_id)
    return 'arn:aws:kinesis:%s:%s:stream/%s' % (
        DEFAULT_REGION, resolved_account, stream_name)


def firehose_stream_arn(stream_name, account_id=None):
    """Return the ARN of Firehose delivery stream `stream_name` in DEFAULT_REGION."""
    resolved_account = get_account_id(account_id)
    return 'arn:aws:firehose:%s:%s:deliverystream/%s' % (
        DEFAULT_REGION, resolved_account, stream_name)


def s3_bucket_arn(bucket_name, account_id=None):
    """Return the ARN of S3 bucket `bucket_name`.

    The ARN format has no region or account fields, so `account_id` is ignored.
    """
    arn_template = 'arn:aws:s3:::%s'
    return arn_template % bucket_name


def sqs_queue_arn(queue_name, account_id=None):
    """Return the ARN of SQS queue `queue_name` in DEFAULT_REGION."""
    resolved_account = get_account_id(account_id)
    return 'arn:aws:sqs:%s:%s:%s' % (DEFAULT_REGION, resolved_account, queue_name)


def sns_topic_arn(topic_name, account_id=None):
    """Return the ARN of SNS topic `topic_name` in DEFAULT_REGION."""
    resolved_account = get_account_id(account_id)
    return 'arn:aws:sns:%s:%s:%s' % (DEFAULT_REGION, resolved_account, topic_name)


def get_sqs_queue_url(queue_name):
    """Look up and return the URL of SQS queue `queue_name` via the GetQueueUrl API."""
    sqs = connect_to_service('sqs')
    return sqs.get_queue_url(QueueName=queue_name)['QueueUrl']


def dynamodb_get_item_raw(request):
    """POST a raw DynamoDB GetItem `request` to config.TEST_DYNAMODB_URL and return the decoded JSON reply."""
    headers = mock_aws_request_headers()
    headers['X-Amz-Target'] = 'DynamoDB_20120810.GetItem'
    payload = json.dumps(request)
    response = make_http_request(url=config.TEST_DYNAMODB_URL, method='POST',
                                 data=payload, headers=headers)
    return json.loads(response.text)


def mock_aws_request_headers(service='dynamodb'):
    """Return fake AWS SigV4-style request headers for `service`.

    The date, credential scope region ('us-east-1') and signature ('1234') are fixed
    placeholders; only the access key (from get_boto3_credentials()) and the service
    name vary. Content type is APPLICATION_AMZ_JSON_1_1 for kinesis, else _1_0.
    """
    content_type = APPLICATION_AMZ_JSON_1_1 if service == 'kinesis' else APPLICATION_AMZ_JSON_1_0
    access_key = get_boto3_credentials().access_key
    authorization = ('AWS4-HMAC-SHA256 Credential=%s/20160623/us-east-1/%s/aws4_request, '
                     'SignedHeaders=content-type;host;x-amz-date;x-amz-target, Signature=1234'
                     ) % (access_key, service)
    return {
        'Content-Type': content_type,
        'Accept-Encoding': 'identity',
        'X-Amz-Date': '20160623T103251Z',
        'Authorization': authorization,
    }


def get_apigateway_integration(api_id, method, path, env=None):
    """Return the API Gateway integration for HTTP `method` on the resource at `path`.

    Resources are requested with limit=100; no further pages are fetched.

    Raises:
        Exception: if no resource with the given path is found.
    """
    client = connect_to_service(service_name='apigateway', client=True, env=env)
    items = client.get_resources(restApiId=api_id, limit=100)['items']
    matching_ids = [item['id'] for item in items if item['path'] == path]
    # the last match wins, as in a full scan without early exit
    resource_id = matching_ids[-1] if matching_ids else None
    if not resource_id:
        raise Exception('Unable to find apigateway integration for path "%s"' % path)
    return client.get_integration(restApiId=api_id, resourceId=resource_id, httpMethod=method)


def get_apigateway_resource_for_path(api_id, path, parent=None, resources=None):
    """Find the API Gateway resource matching `path`, walking down one path part at a time.

    :param api_id: REST API id (only used to fetch resources when `resources` is None)
    :param path: '/'-separated path string, or a list of path parts
    :param parent: resource the first path part must be a child of; None matches any resource
    :param resources: list of resource dicts (the 'items' of a GetResources response);
        fetched from API Gateway (limit=100) when None
    :return: the matching resource dict, `parent` if `path` is empty, or None if not found
    """
    if resources is None:
        apigateway = connect_to_service(service_name='apigateway')
        # get_resources returns a response dict; the resource list is under 'items'
        # (iterating the dict itself would yield its keys)
        resources = apigateway.get_resources(restApiId=api_id, limit=100)['items']
    if not isinstance(path, list):
        path = path.split('/')
    if not path:
        return parent
    for resource in resources:
        # the root resource may lack 'pathPart'/'parentId', so use .get() instead of []
        if resource.get('pathPart') == path[0] and (not parent or parent['id'] == resource.get('parentId')):
            return get_apigateway_resource_for_path(api_id, path[1:], parent=resource, resources=resources)
    return None


def get_apigateway_path_for_resource(api_id, resource_id, path_suffix='', resources=None):
    """Return the full path (e.g. '/foo/bar') of API Gateway resource `resource_id`.

    Walks up the parentId chain, prepending each resource's pathPart to `path_suffix`.

    :param api_id: REST API id (only used to fetch resources when `resources` is None)
    :param resource_id: id of the resource whose path is wanted
    :param path_suffix: already-resolved path below `resource_id` (used by the recursion)
    :param resources: list of resource dicts (the 'items' of a GetResources response);
        fetched from API Gateway (limit=100) when None
    :raises IndexError: if `resource_id` (or one of its ancestors) is not in `resources`
    """
    if resources is None:
        apigateway = connect_to_service(service_name='apigateway')
        # get_resources returns a response dict; the resource list is under 'items'
        resources = apigateway.get_resources(restApiId=api_id, limit=100)['items']
    matches = [res for res in resources if res['id'] == resource_id]
    if not matches:
        raise IndexError('Unable to find API Gateway resource with id "%s"' % resource_id)
    target_resource = matches[0]
    path_part = target_resource.get('pathPart', '')
    # prepend this resource's part, skipping empty pieces (the root has no pathPart)
    path_suffix = '/'.join(part for part in (path_part, path_suffix) if part)
    parent_id = target_resource.get('parentId')
    if not parent_id:
        return '/%s' % path_suffix
    return get_apigateway_path_for_resource(api_id, parent_id, path_suffix=path_suffix, resources=resources)


def create_api_gateway(name, description=None, resources=None, stage_name=None,
        enabled_api_keys=None, env=None, usage_plan_name=None):
    """Create a REST API in API Gateway, add root-level resources/methods, and deploy it.

    :param name: name of the REST API
    :param description: API description; a default mentioning `name` is used if empty
    :param resources: dict mapping a root-level path part to a list of method dicts.
        Each method dict has 'httpMethod' and 'integrations', and optionally
        'authorizationType' (default 'NONE'), 'apiKeyRequired' (default False) and
        'integrationHttpMethod'.
    :param stage_name: deployment stage name (default 'testing')
    :param enabled_api_keys: currently unused
    :param env: environment passed to connect_to_service
    :param usage_plan_name: currently unused
    :return: the create_rest_api response
    :raises Exception: if a resource path contains '/'
    """
    client = connect_to_service('apigateway', env=env)
    # resources is iterated with six.iteritems(), which needs a mapping:
    # an empty list here would raise AttributeError
    resources = resources or {}
    stage_name = stage_name or 'testing'
    description = description or 'Test description for API "%s"' % name

    LOGGER.info('Creating API resources under API Gateway "%s".', name)
    api = client.create_rest_api(name=name, description=description)
    api_id = api['id']
    # assumes the first listed resource is the root '/' (the only resource of a
    # freshly created API) - NOTE(review): confirm against the backend in use
    resources_list = client.get_resources(restApiId=api_id)
    root_res_id = resources_list['items'][0]['id']
    # add API resources and methods
    for path, methods in iteritems(resources):
        if '/' in path:
            raise Exception('Currently only works for root-level resources.')
        api_resource = client.create_resource(restApiId=api_id, parentId=root_res_id, pathPart=path)
        for method in methods:
            client.put_method(
                restApiId=api_id,
                resourceId=api_resource['id'],
                httpMethod=method['httpMethod'],
                authorizationType=method.get('authorizationType') or 'NONE',
                apiKeyRequired=method.get('apiKeyRequired') or False
            )
            # create integrations for this API resource/method
            create_api_gateway_integrations(api_id, api_resource['id'], method,
                method['integrations'], env=env)
    # deploy the API gateway
    client.create_deployment(restApiId=api_id, stageName=stage_name)
    return api


def create_api_gateway_integrations(api_id, resource_id, method, integrations=None, env=None):
    """Create integrations plus integration/method responses for one API Gateway method.

    :param api_id: REST API id
    :param resource_id: id of the resource the method belongs to
    :param method: method dict with 'httpMethod' and optional 'integrationHttpMethod'
        (defaults to 'httpMethod')
    :param integrations: list of integration dicts with 'type' and 'uri', and optionally
        'requestTemplates', 'responseTemplates', 'successCode' (default '200'),
        'clientErrorCode' (default '400') and 'serverErrorCode' (default '500')
    :param env: environment passed to connect_to_service
    """
    client = connect_to_service('apigateway', env=env)
    for integration in integrations or []:
        http_method = method['httpMethod']
        req_templates = integration.get('requestTemplates') or {}
        res_templates = integration.get('responseTemplates') or {}
        success_code = integration.get('successCode') or '200'
        client_error_code = integration.get('clientErrorCode') or '400'
        server_error_code = integration.get('serverErrorCode') or '500'
        # create integration
        client.put_integration(
            restApiId=api_id,
            resourceId=resource_id,
            httpMethod=http_method,
            integrationHttpMethod=method.get('integrationHttpMethod') or http_method,
            type=integration['type'],
            uri=integration['uri'],
            requestTemplates=req_templates
        )
        # map backend status classes (2xx/4xx/5xx) to method response codes;
        # response templates are only applied to the success response
        response_configs = [
            {'pattern': '^2.*', 'code': success_code, 'res_templates': res_templates},
            {'pattern': '^4.*', 'code': client_error_code, 'res_templates': {}},
            {'pattern': '^5.*', 'code': server_error_code, 'res_templates': {}}
        ]
        for response_config in response_configs:
            client.put_integration_response(
                restApiId=api_id,
                resourceId=resource_id,
                httpMethod=http_method,
                statusCode=response_config['code'],
                responseTemplates=response_config['res_templates'],
                selectionPattern=response_config['pattern']
            )
            client.put_method_response(
                restApiId=api_id,
                resourceId=resource_id,
                httpMethod=http_method,
                statusCode=response_config['code']
            )


def get_elasticsearch_endpoint(domain=None, region_name=None):
    """Return the Elasticsearch endpoint URL for `domain`.

    In local mode the URL is taken from the TEST_ELASTICSEARCH_URL environment variable;
    otherwise it is looked up via DescribeElasticsearchDomain and prefixed with 'https://'.
    """
    env = get_environment(region_name=region_name)
    if env.region != REGION_LOCAL:
        es_client = connect_to_service(service_name='es', region_name=env.region)
        domain_status = es_client.describe_elasticsearch_domain(DomainName=domain)['DomainStatus']
        return 'https://%s' % domain_status['Endpoint']
    return os.environ['TEST_ELASTICSEARCH_URL']


def connect_elasticsearch(endpoint=None, domain=None, region_name=None, env=None):
    """Create an Elasticsearch client.

    The endpoint is resolved in this order: the explicit `endpoint`; in local mode the
    TEST_ELASTICSEARCH_URL environment variable; otherwise the endpoint of `domain`
    (via get_elasticsearch_endpoint). SSL is used for 'https://' endpoints, with
    certificate verification only outside local mode.

    Requests are signed with AWS4Auth when CUSTOM_BOTO3_SESSION is set, or when both
    AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are present in the environment.
    Credentials from CUSTOM_BOTO3_SESSION take precedence over the environment variables.

    :raises Exception: if no endpoint can be determined (e.g. non-local mode with
        neither `endpoint` nor `domain`)
    """
    from elasticsearch import Elasticsearch, RequestsHttpConnection
    from requests_aws4auth import AWS4Auth

    env = get_environment(env, region_name=region_name)
    is_local = env.region == REGION_LOCAL
    if not endpoint:
        if is_local:
            endpoint = os.environ['TEST_ELASTICSEARCH_URL']
        elif domain:
            endpoint = get_elasticsearch_endpoint(domain=domain, region_name=env.region)
    if not endpoint:
        # previously this surfaced as an opaque TypeError from "'https://' in None"
        raise Exception('Unable to determine Elasticsearch endpoint: specify "endpoint" or "domain"')

    use_ssl = 'https://' in endpoint
    verify_certs = use_ssl and not is_local

    if CUSTOM_BOTO3_SESSION or (ENV_ACCESS_KEY in os.environ and ENV_SECRET_KEY in os.environ):
        if CUSTOM_BOTO3_SESSION:
            credentials = CUSTOM_BOTO3_SESSION.get_credentials()
            access_key = credentials.access_key
            secret_key = credentials.secret_key
            session_token = credentials.token
        else:
            access_key = os.environ.get(ENV_ACCESS_KEY)
            secret_key = os.environ.get(ENV_SECRET_KEY)
            session_token = os.environ.get(ENV_SESSION_TOKEN)
        awsauth = AWS4Auth(access_key, secret_key, env.region, 'es', session_token=session_token)
        return Elasticsearch(hosts=[endpoint], verify_certs=verify_certs, use_ssl=use_ssl,
                             connection_class=RequestsHttpConnection, http_auth=awsauth)
    return Elasticsearch(hosts=[endpoint], verify_certs=verify_certs, use_ssl=use_ssl)


def create_kinesis_stream(stream_name, shards=1, env=None, delete=False):
    """Create Kinesis stream `stream_name` with `shards` shards, wait for it, and return the KinesisStream.

    If `delete` is true, stream.destroy() is attempted first through run_safe(..., print_error=False)
    (presumably so that a missing stream does not abort creation - confirm run_safe semantics).
    """
    env = get_environment(env)
    kinesis_stream = KinesisStream(id=stream_name, num_shards=shards)
    kinesis_stream.connect(connect_to_service('kinesis', env=env))
    if delete:
        run_safe(lambda: kinesis_stream.destroy(), print_error=False)
    kinesis_stream.create()
    kinesis_stream.wait_for()
    return kinesis_stream


def kinesis_get_latest_records(stream_name, shard_id, count=10, env=None):
    """Return (at most) the last `count` records of a Kinesis shard.

    The shard is read from TRIM_HORIZON until a get_records call returns no records.
    Each record's 'Data' is converted with to_str() where possible; if the conversion
    raises, the record is kept with its original 'Data'. A `count` <= 0 yields [].
    """
    kinesis = connect_to_service('kinesis', env=env)
    result = []
    response = kinesis.get_shard_iterator(StreamName=stream_name, ShardId=shard_id,
        ShardIteratorType='TRIM_HORIZON')
    shard_iterator = response['ShardIterator']
    while shard_iterator:
        records_response = kinesis.get_records(ShardIterator=shard_iterator)
        records = records_response['Records']
        for record in records:
            try:
                record['Data'] = to_str(record['Data'])
            except Exception:
                # deliberate best effort: keep the raw payload if it cannot be converted
                pass
        result.extend(records)
        # stop once a batch comes back empty
        shard_iterator = records_response['NextShardIterator'] if records else False
        # keep memory bounded by retaining only the newest `count` records;
        # one slice deletion replaces the former repeated O(n) list.pop(0) calls
        excess = len(result) - count
        if excess > 0:
            del result[:excess]
    return result