# localstack/utils/common.py
from __future__ import print_function
import threading
import traceback
import os
import hashlib
import uuid
import time
import glob
import subprocess
import six
import shutil
import socket
import json
import decimal
import logging
import tempfile
import requests
from io import BytesIO
from contextlib import closing
from datetime import datetime
from six.moves.urllib.parse import urlparse
from six.moves import cStringIO as StringIO
from six import with_metaclass
from multiprocessing.dummy import Pool
from localstack.utils.compat import bytes_
from localstack.constants import *
from localstack.config import DEFAULT_ENCODING
# registries of temporary files (removed by cleanup_tmp_files) and of
# background threads (stopped by cleanup_threads_and_processes)
TMP_FILES = []
TMP_THREADS = []
# cache cleaning settings used by clean_cache(); values are compared against
# the epoch-seconds values returned by now() and os.path.getmtime()
# CACHE_CLEAN_TIMEOUT: minimum interval between two cache clean runs
# CACHE_MAX_AGE: cache files older than this are removed
CACHE_CLEAN_TIMEOUT = 60 * 5
CACHE_MAX_AGE = 60 * 60
# glob pattern of the cache files; run() replaces '*' with the md5 of the command
CACHE_FILE_PATTERN = os.path.join(tempfile.gettempdir(), 'cache.*.json')
# mutable holder for the time of the last cache clean (updated in place by clean_cache)
last_cache_clean_time = {'time': 0}
# mutex_clean guards clean_cache(); mutex_popen guards subprocess.Popen calls in run()
mutex_clean = threading.Semaphore(1)
mutex_popen = threading.Semaphore(1)
# timestamp format strings (strftime syntax); the second variant includes microseconds ('%f')
TIMESTAMP_FORMAT = '%Y-%m-%dT%H:%M:%S'
TIMESTAMP_FORMAT_MILLIS = '%Y-%m-%dT%H:%M:%S.%fZ'
# module-level logger
LOGGER = logging.getLogger(__name__)
class CustomEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes Decimal and datetime values.

    - ``decimal.Decimal`` becomes an ``int`` when it has no fractional part,
      otherwise a ``float``.
    - ``datetime`` becomes its ``str()`` representation.
    Any other unsupported type is delegated to the base class (which raises
    ``TypeError``).
    """

    def default(self, o):
        if isinstance(o, decimal.Decimal):
            # Compare with "!= 0" rather than "> 0": the remainder of a
            # negative Decimal is negative (e.g. Decimal('-1.5') % 1 == -0.5),
            # so "> 0" would silently truncate -1.5 to -1.
            if o % 1 != 0:
                return float(o)
            return int(o)
        if isinstance(o, datetime):
            return str(o)
        return super(CustomEncoder, self).default(o)
class FuncThread(threading.Thread):
    """Daemon thread that calls ``func(params)`` when started.

    Exceptions raised by ``func`` are swallowed; unless ``quiet`` is set, they
    are logged as a warning together with the traceback.
    """

    def __init__(self, func, params, quiet=False):
        super(FuncThread, self).__init__()
        self.func = func
        self.params = params
        self.quiet = quiet
        self.daemon = True

    def run(self):
        """Invoke the wrapped function with the stored parameters."""
        try:
            self.func(self.params)
        except Exception:
            if self.quiet:
                return
            details = (self.func, self.params, traceback.format_exc())
            LOGGER.warning("Thread run method %s(%s) failed: %s" % details)

    def stop(self, quiet=False):
        """Stopping is not supported for plain function threads; only logs a warning."""
        if quiet or self.quiet:
            return
        LOGGER.warning("Not implemented: FuncThread.stop(..)")
class ShellCommandThread(FuncThread):
    """Thread that runs a shell command as a child process via ``run()``.

    :param cmd: shell command line to execute
    :param params: passed through to ``FuncThread`` (unused by ``run_cmd``)
    :param outfile: file name / file object for the process output, or
        ``subprocess.PIPE`` to forward the output to this process' stdout/stderr.
        Defaults to ``os.devnull``.
    :param env_vars: additional environment variables for the child process
    :param stdin: whether to attach a stdin pipe to the child process
    :param quiet: suppress warning logs on errors / non-zero exit codes
    :param inherit_cwd: run the command in the current working directory
    """

    def __init__(self, cmd, params=None, outfile=None, env_vars=None, stdin=False,
                 quiet=True, inherit_cwd=False):
        self.cmd = cmd
        self.process = None
        self.outfile = outfile or os.devnull
        self.stdin = stdin
        # use fresh dicts per instance instead of shared mutable default arguments
        self.env_vars = {} if env_vars is None else env_vars
        self.inherit_cwd = inherit_cwd
        FuncThread.__init__(self, self.run_cmd, {} if params is None else params, quiet=quiet)

    def run_cmd(self, params):
        """Start the process and wait for it, forwarding piped output if requested."""
        # "sys" is not imported at module level, hence import it locally
        import sys

        def convert_line(line):
            line = to_str(line)
            return line.strip() + '\r\n'

        def forward_output(source, target):
            # source is None if the stream is not piped (e.g. stderr merged into stdout)
            if source is None:
                return
            while True:
                line = source.readline()
                # readline() returns '' (text) or b'' (bytes) at EOF; checking
                # truthiness handles both, whereas comparing to '' never
                # matches b'' on Python 3
                if not line:
                    break
                if not line.strip() and self.is_killed():
                    break
                target.write(convert_line(line))
                target.flush()

        try:
            # "async" is a reserved keyword in Python 3.7+, hence it cannot be
            # written as a literal keyword argument
            self.process = run(self.cmd, stdin=self.stdin, outfile=self.outfile,
                env_vars=self.env_vars, inherit_cwd=self.inherit_cwd, **{'async': True})
            if self.outfile:
                if self.outfile == subprocess.PIPE:
                    # get stdout/stderr from child process and write to parent output
                    forward_output(self.process.stdout, sys.stdout)
                    forward_output(self.process.stderr, sys.stderr)
                    self.process.wait()
                else:
                    self.process.communicate()
        except Exception as e:
            if self.process and not self.quiet:
                LOGGER.warning('Shell command error "%s": %s' % (e, self.cmd))
        if self.process and not self.quiet and self.process.returncode != 0:
            LOGGER.warning('Shell command exit code "%s": %s' % (self.process.returncode, self.cmd))

    def is_killed(self):
        """Return True if no process was started or its pid no longer exists."""
        if not self.process:
            return True
        # Note: Do NOT import "psutil" at the root scope, as this leads
        # to problems when importing this file from our test Lambdas in Docker
        # (Error: libc.musl-x86_64.so.1: cannot open shared object file)
        import psutil
        return not psutil.pid_exists(self.process.pid)

    def stop(self, quiet=False):
        """Kill the child process together with all of its descendants."""
        # Note: Do NOT import "psutil" at the root scope, as this leads
        # to problems when importing this file from our test Lambdas in Docker
        # (Error: libc.musl-x86_64.so.1: cannot open shared object file)
        import psutil
        if not self.process:
            LOGGER.warning("No process found for command '%s'" % self.cmd)
            return
        parent_pid = self.process.pid
        try:
            parent = psutil.Process(parent_pid)
            for child in parent.children(recursive=True):
                child.kill()
            parent.kill()
            self.process = None
        except Exception:
            if not quiet:
                # previously referenced an undefined name "pid" (NameError)
                LOGGER.warning('Unable to kill process with pid %s' % parent_pid)
def is_string(s, include_unicode=True):
    """Return True if ``s`` is a ``str`` or (optionally) a unicode text string."""
    if isinstance(s, str):
        return True
    return bool(include_unicode and isinstance(s, six.text_type))
def md5(string):
    """Return the hex-encoded MD5 digest of the given string."""
    return hashlib.md5(bytes_(string)).hexdigest()
def is_port_open(port_or_url):
    """Check whether a TCP connection can be established.

    :param port_or_url: either a port number (checked on 127.0.0.1), or a URL
        string whose host and port are checked. If the URL carries no explicit
        port, 443 is assumed for "https" and 80 otherwise.
    :return: True if the connection attempt succeeded within one second
    """
    port = port_or_url
    host = '127.0.0.1'
    if isinstance(port, six.string_types):
        url = urlparse(port_or_url)
        host = url.hostname
        port = url.port
        if port is None:
            # no explicit port in the URL; fall back to the scheme's default
            # instead of passing None to connect_ex (which raises TypeError)
            port = 443 if url.scheme == 'https' else 80
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.settimeout(1)
        return sock.connect_ex((host, port)) == 0
def timestamp(time=None, format=TIMESTAMP_FORMAT):
    """Format a point in time as a string.

    ``time`` may be a datetime or a numeric epoch value (converted via
    ``datetime.fromtimestamp``); a falsy value means "current UTC time".
    """
    numeric_types = six.integer_types + (float, )
    moment = time if time else datetime.utcnow()
    if isinstance(moment, numeric_types):
        moment = datetime.fromtimestamp(moment)
    return moment.strftime(format)
def retry(function, retries=3, sleep=1, sleep_before=0, **kwargs):
    """Call ``function(**kwargs)``, retrying on any exception.

    :param function: callable to invoke
    :param retries: number of retries after the first attempt (the function is
        called at most ``retries + 1`` times, and always at least once)
    :param sleep: seconds to wait between two attempts
    :param sleep_before: seconds to wait before the first attempt
    :return: the return value of the first successful call
    :raises: the exception of the last attempt if all attempts fail
    """
    if sleep_before > 0:
        time.sleep(sleep_before)
    attempts = max(retries, 0) + 1
    last_error = None
    for attempt in range(attempts):
        try:
            return function(**kwargs)
        except Exception as error:
            last_error = error
            # no point in sleeping after the final attempt - fail right away
            if attempt < attempts - 1:
                time.sleep(sleep)
    raise last_error
def dump_thread_info():
    """Print all live threads, followed by the node/java/python processes from ``ps``."""
    for thread in threading.enumerate():
        print(thread)
    process_listing = run("ps aux | grep 'node\\|java\\|python'")
    print(process_listing)
def merge_recursive(source, destination):
    """Recursively merge the dict ``source`` into the dict ``destination``.

    Values from ``source`` take precedence. Nested dicts are merged key by
    key; if ``destination`` holds a non-dict value where ``source`` has a
    dict, the value is replaced by a new dict.

    :return: ``destination`` (modified in place)
    """
    for key, value in source.items():
        if isinstance(value, dict):
            # get node or create one; a non-dict node cannot be merged into
            # and is replaced (source wins, as for scalar values)
            node = destination.get(key)
            if not isinstance(node, dict):
                node = destination[key] = {}
            merge_recursive(value, node)
        else:
            if not isinstance(destination, dict):
                LOGGER.warning('Destination for merging %s=%s is not dict: %s' %
                    (key, value, destination))
            destination[key] = value
    return destination
def now_utc():
    """Return ``mktime()`` applied to the current ``datetime.utcnow()`` value."""
    current = datetime.utcnow()
    return mktime(current)
def now():
    """Return ``mktime()`` applied to the current local ``datetime.now()`` value."""
    current = datetime.now()
    return mktime(current)
def mktime(timestamp):
    """Convert a datetime to epoch seconds via ``time.mktime`` on its time tuple."""
    time_tuple = timestamp.timetuple()
    return time.mktime(time_tuple)
def mkdir(folder):
    """Create ``folder`` (including missing parents) unless it already exists."""
    if os.path.exists(folder):
        return
    try:
        os.makedirs(folder)
    except OSError:
        # another thread/process may have created the folder between the
        # existence check and makedirs(); only re-raise genuine failures
        if not os.path.exists(folder):
            raise
def chmod_r(path, mode):
    """Apply ``mode`` to ``path`` and to every directory and file found below it."""
    os.chmod(path, mode)
    for parent, subdirs, files in os.walk(path):
        # directories first, then files - each relative to the current parent
        for entry in subdirs + files:
            os.chmod(os.path.join(parent, entry), mode)
def rm_rf(path):
    """Recursively removes file/directory; does nothing if the path does not exist"""
    if not os.path.exists(path):
        # mirror "rm -rf": a missing path is not an error
        return
    # Make sure all files are writeable and dirs executable to remove
    chmod_r(path, 0o777)
    if os.path.isfile(path):
        os.remove(path)
    else:
        shutil.rmtree(path)
def cp_r(src, dst):
    """Copy a single file, or a whole directory tree, from ``src`` to ``dst``."""
    copy_func = shutil.copy if os.path.isfile(src) else shutil.copytree
    copy_func(src, dst)
def download(url, path):
    """Stream the content at ``url`` into the file at ``path`` (2048-byte chunks)."""
    response = requests.get(url, stream=True)
    # closing() guarantees response.close() even if writing the file fails
    with closing(response):
        with open(path, 'wb') as target:
            for chunk in response.iter_content(2048):
                target.write(chunk)
def short_uid():
    """Return the first 8 characters of a freshly generated random UUID."""
    full_uid = str(uuid.uuid4())
    return full_uid[:8]
def json_safe(item):
    """Round-trip ``item`` through JSON, using CustomEncoder for serialization."""
    serialized = json.dumps(item, cls=CustomEncoder)
    return json.loads(serialized)
def save_file(file, content, append=False):
    """Write ``content`` to ``file``, appending if requested.

    The file is opened in binary mode when ``content`` is not a text string.
    """
    is_text = isinstance(content, six.string_types)
    mode = 'a' if append else 'w+'
    if not is_text:
        mode += 'b'
    with open(file, mode) as out:
        out.write(content)
        out.flush()
def load_file(file_path, default=None, mode=None):
    """Return the content of ``file_path``, or ``default`` if it is not a file.

    ``mode`` is the mode passed to ``open()``; it defaults to ``'r'``.
    """
    if not os.path.isfile(file_path):
        return default
    with open(file_path, mode or 'r') as handle:
        return handle.read()
def to_str(obj):
    """ Decode a bytes object to a string; falsy values and strings are returned as-is """
    if obj and not isinstance(obj, six.string_types):
        return obj.decode(DEFAULT_ENCODING)
    return obj
def to_bytes(obj):
    """ Encode a string to bytes; any non-string object is returned as-is """
    if isinstance(obj, six.string_types):
        return obj.encode(DEFAULT_ENCODING)
    return obj
def cleanup(files=True, env=ENV_DEV, quiet=True):
    """Remove the registered temporary files if ``files`` is set.

    ``env`` and ``quiet`` are accepted but not used by this function.
    """
    if not files:
        return
    cleanup_tmp_files()
def cleanup_threads_and_processes(quiet=True):
    """Call ``stop(quiet=...)`` on every thread registered in TMP_THREADS."""
    for worker in TMP_THREADS:
        worker.stop(quiet=quiet)
def cleanup_tmp_files():
    """Best-effort removal of all paths registered in TMP_FILES; empties the list."""
    for tmp in TMP_FILES:
        try:
            if os.path.isdir(tmp) and not os.path.islink(tmp):
                # remove in-process instead of shelling out to 'rm -rf "<path>"',
                # which breaks on paths containing quotes or shell metacharacters
                shutil.rmtree(tmp)
            else:
                # regular files and symlinks (including symlinks to directories)
                os.remove(tmp)
        except Exception:
            pass  # file likely doesn't exist, or permission denied
    del TMP_FILES[:]
def is_zip_file(content):
    """Return whether the given bytes content is a ZIP archive."""
    import zipfile
    return zipfile.is_zipfile(BytesIO(content))
def is_jar_archive(content):
    """Heuristically detect a JAR archive: content mentions 'class' and 'META-INF'."""
    # TODO Simple stupid heuristic to determine whether a file is a JAR archive
    try:
        has_markers = 'class' in content and 'META-INF' in content
    except TypeError:
        # in Python 3 we need to use byte strings for byte-based file content
        has_markers = b'class' in content and b'META-INF' in content
    return has_markers
def is_root():
    """Return whether ``whoami`` reports the user 'root'."""
    return run('whoami').strip() == 'root'
def cleanup_resources():
    """Remove registered temporary files, then stop registered threads."""
    cleanup_tmp_files()
    cleanup_threads_and_processes()
def generate_ssl_cert(target_file=None, overwrite=False, random=False):
    """Generate a self-signed SSL certificate with a new RSA key pair.

    :param target_file: if given, the combined key+cert PEM is written to this
        file, the key alone to "<target_file>.key" and the certificate alone
        to "<target_file>.crt"; all three are registered in TMP_FILES
    :param overwrite: if False and ``target_file`` already exists, nothing is
        generated and None is returned
    :param random: if True (and ``target_file`` is given), a random id is
        inserted into the target file name
    :return: a tuple (target_file, cert_file_name, key_file_name) if ``random``
        and ``target_file`` are set; otherwise the combined PEM content
        (private key followed by certificate) as a string
    """
    # Note: Do NOT import "OpenSSL" at the root scope
    # (Our test Lambdas are importing this file but don't have the module installed)
    from OpenSSL import crypto
    if random and target_file:
        if '.' in target_file:
            target_file = target_file.replace('.', '.%s.' % short_uid(), 1)
        else:
            target_file = '%s.%s' % (target_file, short_uid())
    if target_file and not overwrite and os.path.exists(target_file):
        return
    # create a key pair (2048 bits - 1024-bit RSA keys are rejected by
    # current TLS libraries as too weak)
    k = crypto.PKey()
    k.generate_key(crypto.TYPE_RSA, 2048)
    # create a self-signed cert
    cert = crypto.X509()
    cert.get_subject().C = "AU"
    cert.get_subject().ST = "Some-State"
    cert.get_subject().L = "Some-Locality"
    cert.get_subject().O = "LocalStack Org"
    cert.get_subject().OU = "Testing"
    cert.get_subject().CN = "LocalStack"
    cert.set_serial_number(1000)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
    cert.set_issuer(cert.get_subject())
    cert.set_pubkey(k)
    # sign with SHA-256 (SHA-1 signatures are no longer accepted by modern clients)
    cert.sign(k, 'sha256')
    cert_file_content = to_str(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)).strip()
    key_file_content = to_str(crypto.dump_privatekey(crypto.FILETYPE_PEM, k)).strip()
    file_content = '%s\n%s' % (key_file_content, cert_file_content)
    if target_file:
        save_file(target_file, file_content)
        key_file_name = '%s.key' % target_file
        cert_file_name = '%s.crt' % target_file
        save_file(key_file_name, key_file_content)
        save_file(cert_file_name, cert_file_content)
        TMP_FILES.append(target_file)
        TMP_FILES.append(key_file_name)
        TMP_FILES.append(cert_file_name)
        if random:
            return target_file, cert_file_name, key_file_name
    return file_content
def run_safe(_python_lambda, print_error=True, **kwargs):
    """Call ``_python_lambda(**kwargs)``, swallowing any exception.

    :param print_error: whether to print a message if the call raises
    :return: the callable's return value, or None if it raised
    """
    try:
        return _python_lambda(**kwargs)
    except Exception as e:
        if print_error:
            print('Unable to execute function: %s' % e)
def run(cmd, cache_duration_secs=0, print_error=True, asynchronous=False, stdin=False,
        stderr=subprocess.STDOUT, outfile=None, env_vars=None, inherit_cwd=False, **kwargs):
    """Run a shell command.

    :param cmd: shell command line (executed with shell=True)
    :param cache_duration_secs: if > 0, the output is cached in a temp file
        (keyed by the md5 of ``cmd``) and reused while the file is younger than this
    :param print_error: print the command output if it exits with a non-zero code
    :param asynchronous: if True, start the process and return the ``Popen``
        object instead of waiting for the output. The legacy keyword ``async``
        is still accepted ("async" is a reserved word from Python 3.7 on and
        can no longer be used as a parameter name).
    :param stdin: attach a stdin pipe to the process
    :param stderr: passed through as ``stderr`` to the subprocess call
    :param outfile: (asynchronous mode) file name or file object for stdout
    :param env_vars: extra environment variables, merged over ``os.environ``
    :param inherit_cwd: run the command in the current working directory
    :return: the decoded command output (raw ``check_output`` result if
        ``stdin`` is set), or the ``Popen`` object in asynchronous mode
    :raises subprocess.CalledProcessError: if the command exits with non-zero code
    :raises TypeError: on unexpected keyword arguments
    """
    # backwards compatibility for callers passing the legacy keyword "async"
    if 'async' in kwargs:
        asynchronous = kwargs.pop('async')
    if kwargs:
        raise TypeError('run() got unexpected keyword argument(s): %s' % ', '.join(sorted(kwargs)))

    # don't use subprocess module as it is not thread-safe
    # http://stackoverflow.com/questions/21194380/is-subprocess-popen-not-thread-safe
    if six.PY2:
        import subprocess32 as subprocess
    else:
        import subprocess

    env_dict = os.environ.copy()
    if env_vars:
        env_dict.update(env_vars)

    def do_run(cmd):
        try:
            cwd = os.getcwd() if inherit_cwd else None
            if not asynchronous:
                if stdin:
                    return subprocess.check_output(cmd, shell=True,
                        stderr=stderr, stdin=subprocess.PIPE, env=env_dict, cwd=cwd)
                output = subprocess.check_output(cmd, shell=True, stderr=stderr, env=env_dict, cwd=cwd)
                return output.decode(DEFAULT_ENCODING)
            # subprocess.Popen is not thread-safe, hence use a mutex here..
            with mutex_popen:
                stdin_arg = subprocess.PIPE if stdin else None
                stdout_arg = open(outfile, 'wb') if isinstance(outfile, six.string_types) else outfile
                return subprocess.Popen(cmd, shell=True, stdin=stdin_arg, bufsize=-1,
                    stderr=stderr, stdout=stdout_arg, env=env_dict, cwd=cwd)
        except subprocess.CalledProcessError as e:
            if print_error:
                print("ERROR: '%s': %s" % (cmd, e.output))
            # bare raise keeps the original traceback
            raise

    if cache_duration_secs <= 0:
        return do_run(cmd)

    cache_file = CACHE_FILE_PATTERN.replace('*', md5(cmd))
    if os.path.isfile(cache_file):
        # check file age
        mod_time = os.path.getmtime(cache_file)
        if mod_time > (now() - cache_duration_secs):
            with open(cache_file) as f:
                return f.read()
    result = do_run(cmd)
    with open(cache_file, 'w+') as f:
        f.write(result)
    clean_cache()
    return result
def clone(item):
    """Return a deep copy of ``item`` produced by a JSON round trip."""
    serialized = json.dumps(item)
    return json.loads(serialized)
def remove_non_ascii(text):
    """Decode ``text`` as UTF-8 and re-encode it as ASCII.

    Both steps use the CODEC_HANDLER_UNDERSCORE error handler.
    """
    decoded = text.decode('utf-8', CODEC_HANDLER_UNDERSCORE)
    return decoded.encode('ascii', CODEC_HANDLER_UNDERSCORE)
class NetrcBypassAuth(requests.auth.AuthBase):
    """No-op auth handler for requests: returns the request unmodified."""

    def __call__(self, r):
        return r
class _RequestsSafe(type):
    """ Wrapper around requests library, which prevents it from verifying
    SSL certificates or reading credentials from ~/.netrc file """

    def __getattr__(self, name):
        method = requests.__dict__.get(name.lower())
        if not method:
            return method

        def _missing(*args, **kwargs):
            # only fill in the defaults the caller did not specify explicitly
            kwargs.setdefault('auth', NetrcBypassAuth())
            kwargs.setdefault('verify', False)
            return method(*args, **kwargs)

        return _missing
# create class-of-a-class
class safe_requests(with_metaclass(_RequestsSafe)):
    """Namespace class; attribute lookups are resolved by the _RequestsSafe metaclass."""
def make_http_request(url, data=None, headers=None, method='GET'):
    """Send an HTTP request without SSL verification and with NetrcBypassAuth.

    ``method`` is either the name of a requests function (e.g. 'GET') or a
    callable with the same signature.
    """
    request_func = requests.__dict__[method.lower()] if is_string(method) else method
    return request_func(url, headers=headers, data=data, auth=NetrcBypassAuth(), verify=False)
def clean_cache(file_pattern=CACHE_FILE_PATTERN,
        last_clean_time=last_cache_clean_time, max_age=CACHE_MAX_AGE):
    """Remove cache files that are older than ``max_age`` seconds.

    The clean-up runs at most once per CACHE_CLEAN_TIMEOUT seconds.

    :param file_pattern: glob pattern matching the cache files
    :param last_clean_time: dict whose 'time' entry holds the time of the last
        clean-up; updated in place
    :param max_age: maximum age of a cache file in seconds
    :return: the current time if a clean-up was performed, otherwise None
    """
    # the context manager releases the mutex even if now() raises
    with mutex_clean:
        time_now = now()
        if last_clean_time['time'] > time_now - CACHE_CLEAN_TIMEOUT:
            return
        for cache_file in set(glob.glob(file_pattern)):
            try:
                mod_time = os.path.getmtime(cache_file)
                if time_now > mod_time + max_age:
                    rm_rf(cache_file)
            except OSError:
                # the file was removed concurrently (e.g. by another process
                # sharing the temp directory) - nothing left to clean up
                continue
        last_clean_time['time'] = time_now
    return time_now
def truncate(data, max_length=100):
    """Cut ``data`` to ``max_length`` items and append '...' if it is longer."""
    if len(data) <= max_length:
        return data
    return data[:max_length] + '...'
def parallelize(func, list, size=None):
    """Apply ``func`` to each item of ``list`` using a pool of threads.

    :param func: function taking a single item
    :param list: the items to process
    :param size: number of worker threads; defaults to the number of items
    :return: the list of results in input order, or None if there is nothing to do
    """
    if not size:
        size = len(list)
    if size <= 0:
        return None
    pool = Pool(size)
    try:
        return pool.map(func, list)
    finally:
        # always shut down the workers, also if func raised
        pool.close()
        pool.join()