localstack/services/sns/sns_listener.py
import json
import logging
import requests
import xmltodict
from requests.models import Response
from six.moves.urllib import parse as urlparse
from localstack.services.awslambda import lambda_api
from localstack.utils.aws import aws_stack
from localstack.utils.common import short_uid
# In-memory subscription registry: topic ARN -> list of subscription dicts.
# Each dict carries 'TopicArn', 'Endpoint', 'Protocol', 'SubscriptionArn' and
# 'RawMessageDelivery'; SQS subscriptions may additionally cache 'sqs_queue_url'.
SNS_SUBSCRIPTIONS = dict()

# Module-level logger for this listener
LOGGER = logging.getLogger(__name__)
def update_sns(method, path, data, headers, response=None, return_forward_info=False):
    """Proxy listener for the SNS API.

    The proxy calls this function in two phases:

    * ``return_forward_info=True`` -- before the request is forwarded to the
      backend. Returning a ``Response`` answers the request locally (it is not
      forwarded); returning ``True`` lets the proxy forward it.
    * ``return_forward_info=False`` -- after the backend has replied, with the
      reply in ``response``. Used to record successful subscriptions.

    Only form-encoded ``POST /`` API calls are inspected; anything else passes through.

    :param method: HTTP method of the request
    :param path: request path
    :param data: raw request body (form-encoded query string)
    :param headers: request headers (unused)
    :param response: backend response; only meaningful in the second phase
    :param return_forward_info: selects the phase, see above
    :return: a ``Response`` or ``True`` in the first phase, ``None`` in the second
    """
    is_api_call = method == 'POST' and path == '/'
    if return_forward_info:
        if is_api_call:
            reply = _handle_api_request(urlparse.parse_qs(data))
            if reply is not None:
                # answer locally - the request must not be forwarded to the backend
                return reply
        return True
    # Second phase: the backend has already answered, so ``response`` is available.
    if is_api_call:
        _record_subscription(urlparse.parse_qs(data), response)


def _get_topic_arn(req_data):
    """Return the TargetArn (preferred) or TopicArn of a parsed request, or None."""
    values = req_data.get('TargetArn') or req_data.get('TopicArn')
    return values[0] if values else None


def _handle_api_request(req_data):
    """Inspect an SNS request before forwarding; return a Response to answer locally, or None to forward."""
    req_action = req_data.get('Action', [None])[0]
    topic_arn = _get_topic_arn(req_data)
    if topic_arn:
        SNS_SUBSCRIPTIONS.setdefault(topic_arn, [])

    if req_action == 'SetSubscriptionAttributes':
        sub = get_subscription_by_arn(req_data['SubscriptionArn'][0])
        if not sub:
            return make_error(message='Unable to find subscription for given ARN', code=400)
        sub[req_data['AttributeName'][0]] = req_data['AttributeValue'][0]
        return make_response(req_action)

    if req_action == 'GetSubscriptionAttributes':
        sub = get_subscription_by_arn(req_data['SubscriptionArn'][0])
        if not sub:
            return make_error(message='Unable to find subscription for given ARN', code=400)
        return make_response(req_action, content=_format_subscription_attributes(sub))

    if req_action == 'Subscribe':
        if 'Endpoint' not in req_data:
            return make_error(message='Endpoint not specified in subscription', code=400)
        # valid subscribe requests are forwarded; the result is recorded in the second phase
        return None

    if req_action == 'Publish':
        _publish(topic_arn, req_data)
        # publishing is handled entirely here, the backend never sees it
        return make_response(req_action)

    return None


def _format_subscription_attributes(sub):
    """Render a subscription dict as the XML <Attributes> element of GetSubscriptionAttributes."""
    from xml.sax.saxutils import escape
    entries = ''.join(
        '<entry><key>%s</key><value>%s</value></entry>\n' % (escape(str(key)), escape(str(value)))
        for key, value in sub.items()
        # 'sqs_queue_url' is an internal cache written by _publish, not an SNS attribute
        if key != 'sqs_queue_url'
    )
    return '<Attributes>%s</Attributes>' % entries


def _publish(topic_arn, req_data):
    """Deliver a Publish request's message to every known subscriber of ``topic_arn``."""
    message = req_data['Message'][0]
    # parse_qs yields lists; take the first value (or None) rather than the list itself
    subject = req_data.get('Subject', [None])[0]
    sqs_client = aws_stack.connect_to_service('sqs')
    # .get(): a Publish without a known TopicArn/TargetArn simply has no subscribers
    for subscriber in SNS_SUBSCRIPTIONS.get(topic_arn, []):
        protocol = subscriber['Protocol']
        if protocol == 'sqs':
            queue_url = subscriber.get('sqs_queue_url')
            if not queue_url:
                # endpoint is presumably an SQS queue ARN; field 5 is the queue name
                queue_name = subscriber['Endpoint'].split(':')[5]
                queue_url = aws_stack.get_sqs_queue_url(queue_name)
                # cache the resolved URL so later publishes skip the lookup
                subscriber['sqs_queue_url'] = queue_url
            sqs_client.send_message(QueueUrl=queue_url,
                                    MessageBody=create_sns_message_body(subscriber, req_data))
        elif protocol == 'lambda':
            lambda_api.process_sns_notification(
                subscriber['Endpoint'],
                topic_arn, message, subject=subject
            )
        elif protocol in ('http', 'https'):
            requests.post(
                subscriber['Endpoint'],
                headers={
                    'Content-Type': 'text/plain',
                    'x-amz-sns-message-type': 'Notification'
                },
                data=create_sns_message_body(subscriber, req_data)
            )
        else:
            LOGGER.warning('Unexpected protocol "%s" for SNS subscription', protocol)


def _record_subscription(req_data, response):
    """After the backend accepted a Subscribe request, store the new subscription locally."""
    if req_data.get('Action', [None])[0] != 'Subscribe':
        return
    if response is None or response.status_code >= 400:
        return
    response_data = xmltodict.parse(response.content)
    topic_arn = _get_topic_arn(req_data)
    sub_arn = response_data['SubscribeResponse']['SubscribeResult']['SubscriptionArn']
    subscription = {
        # http://docs.aws.amazon.com/cli/latest/reference/sns/get-subscription-attributes.html
        'TopicArn': topic_arn,
        'Endpoint': req_data['Endpoint'][0],
        'Protocol': req_data['Protocol'][0],
        'SubscriptionArn': sub_arn,
        'RawMessageDelivery': 'false'
    }
    SNS_SUBSCRIPTIONS.setdefault(topic_arn, []).append(subscription)
# ---------------
# HELPER METHODS
# ---------------
def get_subscription_by_arn(sub_arn):
    """Find a registered subscription by its SubscriptionArn.

    Topics are scanned in SNS_SUBSCRIPTIONS order. The stored dict itself is
    returned (not a copy), so changes to it update the registry.

    :param sub_arn: subscription ARN to look for
    :return: the matching subscription dict, or None if there is none
    """
    # TODO maintain separate map instead of traversing all items
    matches = (
        candidate
        for topic_subs in SNS_SUBSCRIPTIONS.values()
        for candidate in topic_subs
        if candidate['SubscriptionArn'] == sub_arn
    )
    return next(matches, None)
def make_response(op_name, content=''):
    """Build a successful (HTTP 200) SNS XML response.

    :param op_name: API action name; used for the <op_name>Response and
        <op_name>Result element names
    :param content: XML placed verbatim inside the result element; when empty,
        a <MessageId> element holding a fresh short_uid() is used instead
    :return: a ``Response`` whose body is the XML document
    """
    # the MessageId uid (if any) is drawn before the RequestId uid
    body = content or '<MessageId>%s</MessageId>' % short_uid()
    template = """<{op_name}Response xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<{op_name}Result>
{content}
</{op_name}Result>
<ResponseMetadata><RequestId>{req_id}</RequestId></ResponseMetadata>
</{op_name}Response>"""
    result = Response()
    result.status_code = 200
    result._content = template.format(op_name=op_name, content=body, req_id=short_uid())
    return result
def make_error(message, code=400, code_string='InvalidParameter'):
    """Build an SNS XML error response (fault type 'Sender').

    :param message: human-readable error text, inserted verbatim (not XML-escaped)
    :param code: HTTP status code of the response
    :param code_string: value of the <Code> element
    :return: a ``Response`` carrying the <ErrorResponse> document
    """
    body = """<ErrorResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/"><Error>
<Type>Sender</Type>
<Code>{code_string}</Code>
<Message>{message}</Message>
</Error><RequestId>{req_id}</RequestId>
</ErrorResponse>""".format(message=message, code_string=code_string, req_id=short_uid())
    error_response = Response()
    error_response._content = body
    error_response.status_code = code
    return error_response
def create_sns_message_body(subscriber, req_data):
    """Build the payload delivered to one subscriber for a Publish request.

    With raw message delivery enabled on the subscription, the plain message
    text is returned. Otherwise a JSON notification envelope is returned with
    'Type', 'Message', 'TopicArn' and, when present, 'Subject' and
    'MessageAttributes'.

    :param subscriber: subscription dict (reads 'RawMessageDelivery' and 'TopicArn')
    :param req_data: parsed Publish request as returned by ``parse_qs`` (list values)
    :return: the message body string
    """
    message = req_data['Message'][0]
    # Attribute values are stored as given to SetSubscriptionAttributes, so
    # compare case-insensitively ('True' counts too) and treat a missing key
    # as 'false' instead of raising KeyError.
    raw_delivery = str(subscriber.get('RawMessageDelivery', 'false')).lower() == 'true'
    if raw_delivery:
        return message

    data = {
        'Type': 'Notification',
        'Message': message,
        'TopicArn': subscriber['TopicArn'],
    }
    subject = req_data.get('Subject', [None])[0]
    if subject is not None:
        data['Subject'] = subject
    attributes = get_message_attributes(req_data)
    if attributes:
        data['MessageAttributes'] = attributes
    return json.dumps(data)
def get_message_attributes(req_data):
    """Collect the ``MessageAttributes.entry.N.*`` parameters of a Publish request.

    Entries are read for N = 1, 2, 3, ... and stop at the first N that has no
    ``.Name`` parameter. Each attribute name maps to a dict with 'Type' (the
    ``Value.DataType`` parameter, or None) and, if one is present, 'Value'
    (``Value.StringValue``, falling back to ``Value.BinaryValue``).

    :param req_data: parsed request as returned by ``parse_qs`` (only the first
        element of each value list is used)
    :return: dict of attribute name -> attribute dict; empty if there are none
    """
    def first(key):
        return req_data.get(key, [None])[0]

    collected = {}
    index = 1
    while True:
        prefix = 'MessageAttributes.entry.%s' % index
        attr_name = first(prefix + ".Name")
        if attr_name is None:
            return collected
        entry = {'Type': first(prefix + ".Value.DataType")}
        string_value = first(prefix + ".Value.StringValue")
        value = string_value if string_value is not None else first(prefix + ".Value.BinaryValue")
        if value is not None:
            entry['Value'] = value
        collected[attr_name] = entry
        index += 1