atlassian/themis

View on GitHub
themis/util/common.py

Summary

Maintainability
C
1 day
Test Coverage
import decimal
import glob
import json
import logging
import math
import os
import re
import threading
import time
import traceback
import urllib
import uuid
from collections import namedtuple
from datetime import datetime, timedelta

import pyhive.presto
import subprocess32 as subprocess
from collections import namedtuple

# On-disk result cache used by run_cached(): entries are files matching
# CACHE_FILE_PATTERN; clean_cache() sweeps them at most once every
# CACHE_CLEAN_TIMEOUT seconds and deletes entries older than CACHE_MAX_AGE.
CACHE_CLEAN_TIMEOUT = 5 * 60
CACHE_MAX_AGE = 60 * 60
CACHE_FILE_PATTERN = '/tmp/cache.*.json'

# Connect timeout for curl commands (presumably seconds).
CURL_CONNECT_TIMEOUT = 3

# Cache durations for query results, Ganglia data and static info
# (presumably seconds, for use as cache_duration_secs).
QUERY_CACHE_TIMEOUT = 60
GANGLIA_CACHE_TIMEOUT = 60
STATIC_INFO_CACHE_TIMEOUT = 30 * 60

# Mutable module state: timestamp of the last cache sweep, the lock that
# serializes sweeps, and the lock held around subprocess creation in run().
last_cache_clean_time = 0
mutex_clean = threading.RLock()
mutex_popen = threading.RLock()


def get_logger(name=None):
    """Return the ``logging`` logger called ``name`` (the root logger for None)."""
    return logging.getLogger(name)

# Module-wide logger, named after this module.
LOG = logging.getLogger(__name__)


class FuncThread(threading.Thread):
    """Daemon thread that calls ``func(params)`` once.

    :param func: callable invoked with the single positional argument ``params``
    :param params: value handed to ``func``
    :param quiet: when True, suppress the messages printed by ``run`` on
        failure and by ``stop``
    """

    def __init__(self, func, params, quiet=False):
        threading.Thread.__init__(self)
        # Daemon threads do not block interpreter exit.
        self.daemon = True
        self.params = params
        self.func = func
        self.quiet = quiet

    def run(self):
        """Invoke ``func(params)``; print the traceback of any exception unless quiet."""
        try:
            self.func(self.params)
        except Exception:
            if not self.quiet:
                print("Thread run method %s(%s) failed: %s" %
                      (self.func, self.params, traceback.format_exc()))

    def stop(self, quiet=False):
        """Not implemented: only prints a warning (unless either quiet flag is set)."""
        if not quiet and not self.quiet:
            print("WARN: not implemented: FuncThread.stop(..)")


def clean_cache():
    """Delete stale cache files, at most once per CACHE_CLEAN_TIMEOUT seconds.

    Files matching CACHE_FILE_PATTERN whose modification time is more than
    CACHE_MAX_AGE seconds in the past are removed. ``mutex_clean`` ensures only
    one thread of this process sweeps at a time.
    """
    global last_cache_clean_time
    with mutex_clean:
        time_now = now()
        if last_cache_clean_time > time_now - CACHE_CLEAN_TIMEOUT:
            return
        for cache_file in glob.glob(CACHE_FILE_PATTERN):
            try:
                if time_now > os.path.getmtime(cache_file) + CACHE_MAX_AGE:
                    os.remove(cache_file)
            except OSError:
                # The file may vanish between glob() and getmtime()/remove()
                # (e.g. another process cleaning the same directory). Cleanup is
                # best-effort, so just move on to the next file.
                pass
        last_cache_clean_time = time_now


def setup_logging(log_file=None, format='%(asctime)s %(levelname)s: %(name)s: %(message)s'):
    """Configure root logging at INFO level and quieten chatty third-party loggers.

    :param log_file: optional path; when given, records are written to this file
        and additionally echoed through a ``logging.StreamHandler``
    :param format: ``logging`` format string used for every handler
    """
    for noisy_name in ('werkzeug', 'requests', 'botocore'):
        logging.getLogger(noisy_name).setLevel(logging.WARNING)
    if not log_file:
        logging.basicConfig(level=logging.INFO, format=format)
        return
    logging.basicConfig(filename=log_file, level=logging.INFO, format=format)
    console = logging.StreamHandler()
    console.setFormatter(logging.Formatter(format))
    logging.getLogger().addHandler(console)


def now():
    """Return the current local time as POSIX seconds, truncated to a whole second (float)."""
    local_struct = datetime.now().timetuple()
    return time.mktime(local_struct)


def json_namedtuple(json_string):
    """Parse ``json_string``, turning every JSON object into a namedtuple called ``X``.

    Field names are the object's keys, so they must be valid identifiers;
    arrays and scalars are returned as ``json.loads`` produces them.
    """
    def _object_to_tuple(mapping):
        record_type = namedtuple('X', mapping.keys())
        return record_type(*mapping.values())

    return json.loads(json_string, object_hook=_object_to_tuple)


def load_json_file(file, default=None):
    """Return the parsed JSON content of ``file``.

    :param file: path to read
    :param default: returned when ``file`` is not an existing regular file
    :raises ValueError: if the file exists but does not contain valid JSON
    """
    if not os.path.isfile(file):
        return default
    # ``with`` closes the handle even when parsing fails.
    with open(file) as f:
        return json.load(f)


def save_file(file, content):
    """Write the string ``content`` to ``file``, replacing any existing content."""
    # ``with`` closes the handle even if the write fails.
    with open(file, 'w+') as f:
        f.write(content)


def save_json_file(file, content):
    """Serialize ``content`` with ``json.dumps`` and write the text to ``file`` via ``save_file``."""
    serialized = json.dumps(content)
    save_file(file, serialized)


def json_defaults(obj):
    """``default=`` hook for ``json.dumps`` covering Decimal and datetime values.

    - ``decimal.Decimal``: finite integral values become ``int``; everything else
      (fractional, NaN, Infinity) becomes ``float``
    - ``datetime``: formatted as ``'%Y-%m-%dT%H:%M:%S.%fZ'``
    - any other object is returned unchanged

    NOTE(review): returning an unsupported object unchanged still makes
    ``json.dumps`` fail; confirm whether raising TypeError was intended.
    """
    if isinstance(obj, decimal.Decimal):
        # Compare against the integral value instead of testing ``obj % 1 > 0``:
        # Decimal's remainder keeps the dividend's sign (-1.5 % 1 == -0.5), and
        # ``%`` raises for NaN/Infinity and for values exceeding the context precision.
        if obj.is_finite() and obj == obj.to_integral_value():
            return int(obj)
        return float(obj)
    if isinstance(obj, datetime):
        TIMESTAMP_FORMAT_MS = '%Y-%m-%dT%H:%M:%S.%fZ'
        return obj.strftime(TIMESTAMP_FORMAT_MS)
    return obj


def json_fix(data):
    """
    Return a JSON-serializable copy of ``data``.

    The value is round-tripped through ``json.dumps`` (with ``json_defaults`` as
    the fallback encoder, e.g. for ``decimal.Decimal`` and ``datetime``) and then
    ``json.loads``. The result therefore contains only plain JSON types: tuples
    come back as lists and non-string dict keys as strings.
    """
    encoded = json.dumps(data, default=json_defaults)
    return json.loads(encoded)


def is_composite(o):
    """Return True if ``o`` is a list or a dict (including subclasses)."""
    return isinstance(o, (list, dict))


def is_float(f):
    """Return True if ``f`` is a ``float`` that is not NaN (infinities count as floats)."""
    if not isinstance(f, float):
        return False
    return not math.isnan(f)


def remove_lines_from_string(s, regex):
    """Return ``s`` without the ``'\\n'``-separated lines that ``re.match(regex, line)`` accepts.

    ``re.match`` anchors at the start of each line only.
    """
    kept_lines = (line for line in s.split('\n') if re.match(regex, line) is None)
    return '\n'.join(kept_lines)


def is_number(s):
    """Return True when ``float(s)`` succeeds, False when it raises any exception."""
    try:
        float(s)
    except Exception:
        return False
    return True


def is_ip_address(s):
    """Return a match object if ``s`` is shaped like a dotted-quad IPv4 address, else None.

    Only the shape is checked (four dot-separated runs of 1-3 digits); octet
    values are not range-checked, so '999.0.0.1' also matches.
    """
    dotted_quad = r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$'
    return re.match(dotted_quad, s)


def is_NaN(obj, expect_only_numbers=False):
    """Return True if ``obj`` should be treated as "not a number".

    True when any of the following holds, checked in order:
    - ``expect_only_numbers`` is set and ``is_number(obj)`` is False
    - ``obj`` equals the string ``'NaN'``
    - ``obj`` is a float NaN
    - ``obj`` equals positive or negative infinity (checked for any type,
      not only floats)
    """
    if expect_only_numbers and not is_number(obj):
        return True
    if obj == 'NaN':
        return True
    if isinstance(obj, float) and math.isnan(obj):
        return True
    return obj in [float('Inf'), -float('Inf')]


def remove_NaN(obj, delete_values=True, replacement='NaN', expect_only_numbers=False):
    """Recursively delete or replace NaN-like values inside nested lists/dicts, in place.

    A value is NaN-like when ``is_NaN(value, expect_only_numbers)`` is true.
    Nested lists/dicts are always descended into and are never removed themselves.

    :param obj: list or dict to clean; any other value is returned untouched
    :param delete_values: True to delete NaN-like entries, False to overwrite
        them with ``replacement``
    :param replacement: value stored in place of NaN-like entries when
        ``delete_values`` is False
    :param expect_only_numbers: forwarded to ``is_NaN``
    :return: ``obj`` itself, mutated in place
    """
    if isinstance(obj, list):
        survivors = []
        for element in obj:
            if is_composite(element):
                remove_NaN(element, delete_values, replacement, expect_only_numbers)
                survivors.append(element)
            elif not is_NaN(element, expect_only_numbers):
                survivors.append(element)
            elif not delete_values:
                survivors.append(replacement)
        # Slice assignment keeps the caller's list object (identity preserved).
        obj[:] = survivors
    elif isinstance(obj, dict):
        # Snapshot the keys because entries may be deleted during the loop.
        for key in list(obj.keys()):
            value = obj[key]
            if is_composite(value):
                remove_NaN(value, delete_values, replacement, expect_only_numbers)
            elif is_NaN(value, expect_only_numbers):
                if delete_values:
                    del obj[key]
                else:
                    obj[key] = replacement
    return obj


def short_uid():
    """Return the first 8 hex characters of a random UUID4."""
    return uuid.uuid4().hex[:8]


def inject_aws_endpoint(cmd):
    """Point an ``aws <service> <operation> ...`` command at a custom endpoint.

    When the environment variable ``AWS_ENDPOINT_URL`` is set and non-empty, a
    command of the form ``aws <service> <operation><rest>`` becomes
    ``aws --endpoint-url="<AWS_ENDPOINT_URL>/<service>/<operation>" <service> <operation><rest>``.
    Any other command, or any command when the variable is unset or empty, is
    returned unchanged.
    """
    # os.environ is a mapping; attribute access (os.environ.X) always raises.
    endpoint = os.environ.get('AWS_ENDPOINT_URL')
    if not endpoint:
        return cmd
    match = re.match(r'^aws ([^\s]+) ([^\s]+)(.*)$', cmd)
    if not match:
        return cmd
    service, operation, rest = match.groups()
    # Build the text directly rather than via a re.sub template, so backslashes
    # in the endpoint URL are never interpreted as group references.
    rewritten = 'aws --endpoint-url="%s/%s/%s" %s %s%s' % (
        endpoint, service, operation, service, operation, rest)
    # Keep anything after the match ($ may stop before a trailing newline).
    return rewritten + cmd[match.end():]


def inject_env_vars(s):
    """Replace every ``$NAME`` in ``s`` with the value of environment variable NAME.

    NAME is a run of ``[a-zA-Z0-9_]`` characters; unset variables expand to ``''``.
    Substitution is a single pass over the whole name, so:

    - ``$FOO`` never clobbers the prefix of ``$FOOBAR``
    - inserted values are not re-scanned, so a value containing ``$...`` is kept
      verbatim and cannot cause an endless loop
    """
    return re.sub(r'\$([a-zA-Z0-9_]+)',
                  lambda match: os.environ.get(match.group(1), ''),
                  s)


def run_func(func, cache_duration_secs=0, **kwargs):
    """Call ``func(**kwargs)``, optionally through the file cache (alias of ``run_cached``)."""
    return run_cached(func, cache_duration_secs, **kwargs)


def run_cached(func, cache_duration_secs=0, **kwargs):
    """Call ``func(**kwargs)``, caching its result on disk for ``cache_duration_secs``.

    With ``cache_duration_secs <= 0`` this is a plain call and the raw return
    value is passed through. Otherwise:

    - a cache file younger than ``cache_duration_secs`` is returned verbatim
      (as a string);
    - on a miss, the result is converted to a string (non-strings via
      ``json.dumps``, falling back to ``json_fix``), stored, and that string
      is returned. Cached calls therefore always yield strings.

    NOTE(review): the cache key is ``func.__name__`` plus ``str(kwargs)``, so
    different functions sharing a name and kwargs share one cache entry.
    """
    if cache_duration_secs <= 0:
        return func(**kwargs)
    cache_key = md5(func.__name__ + str(kwargs))
    cache_file = CACHE_FILE_PATTERN.replace('*', '%s') % cache_key
    if os.path.isfile(cache_file):
        mod_time = os.path.getmtime(cache_file)
        if mod_time > (now() - cache_duration_secs):
            with open(cache_file) as f:
                return f.read()
    result = func(**kwargs)
    # Serialize BEFORE touching the cache file. If serialization fails, this must
    # not leave an empty-but-fresh entry that later calls would return as ''.
    if not isinstance(result, basestring):
        try:
            result = json.dumps(result)
        except Exception:
            result = json.dumps(json_fix(result))
    # Write to a unique temp file in the same directory, then rename it over the
    # target. Concurrent readers see either the old entry or the complete new
    # one, never a partial write. The '.tmp' suffix keeps it out of
    # CACHE_FILE_PATTERN.
    tmp_file = '%s.%s.tmp' % (cache_file, uuid.uuid4().hex)
    try:
        with open(tmp_file, 'w') as f:
            f.write(result)
        os.rename(tmp_file, cache_file)
    except Exception:
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise
    clean_cache()
    return result


def run(cmd, cache_duration_secs=0, log_error=False, retries=0, sleep=2, backoff=1.4):
    """Run shell command ``cmd`` and return its stdout, with optional caching and retries.

    :param cmd: command string executed with ``shell=True`` (after
        ``inject_aws_endpoint``). Callers must not pass untrusted input.
    :param cache_duration_secs: if > 0, serve/store the output through
        ``run_cached`` for this many seconds
    :param log_error: log the stdout of each failed attempt at ERROR level
    :param retries: number of extra attempts after a non-zero exit status
    :param sleep: seconds to wait before the first retry
    :param backoff: factor by which the wait grows after each retry
    :return: the command's stdout; stderr is captured and discarded
    :raises subprocess.CalledProcessError: if the last attempt exits non-zero.
        Its ``output`` attribute holds that attempt's stdout.
    """
    def run_once(command):
        # Process creation is serialized via mutex_popen. ``with`` guarantees the
        # lock is released even if Popen itself raises (e.g. OSError).
        with mutex_popen:
            process = subprocess.Popen(command, shell=True,
                                       stderr=subprocess.PIPE, stdout=subprocess.PIPE)
        # communicate() drains stdout and stderr together. Reading only stdout
        # can deadlock once the child fills the stderr pipe buffer.
        output, _ = process.communicate()
        if process.returncode != 0:
            raise subprocess.CalledProcessError(process.returncode, command, output=output)
        return output

    # The name ``do_run`` and its ``cmd`` parameter feed run_cached's cache key;
    # keep them unchanged so existing cache entries stay valid.
    def do_run(cmd):
        attempts_left = retries
        delay = sleep
        # Retry in place instead of re-entering run(): this avoids re-applying
        # inject_aws_endpoint to an already-rewritten command.
        while True:
            try:
                return run_once(cmd)
            except subprocess.CalledProcessError as e:
                if log_error:
                    LOG.error("%s" % e.output)
                if attempts_left <= 0:
                    raise
                LOG.info("INFO: Re-running command '%s'" % cmd)
                time.sleep(delay)
                attempts_left -= 1
                delay *= backoff

    cmd = inject_aws_endpoint(cmd)
    kwargs = {'cmd': cmd}
    return run_cached(do_run, cache_duration_secs=cache_duration_secs, **kwargs)


def md5(string):
    """Return the hexadecimal MD5 digest of ``string``.

    Text strings (``unicode`` on Python 2, ``str`` on Python 3) are UTF-8 encoded
    before hashing. Byte strings are hashed as-is.
    """
    import hashlib
    text_type = type(u'')
    if isinstance(string, text_type):
        # hashlib only accepts bytes; a non-ASCII unicode would otherwise fail.
        string = string.encode('utf-8')
    return hashlib.md5(string).hexdigest()


def array_reverse(array):
    """Transpose a (possibly jagged) 2-D sequence: column ``j`` of the input becomes row ``j``.

    Rows may differ in length. Output row ``j`` collects, in order, the ``j``-th
    element of every input row long enough to have one.
    """
    transposed = []
    for row in array:
        for col_index, value in enumerate(row):
            if col_index == len(transposed):
                transposed.append([])
            transposed[col_index].append(value)
    return transposed


def apply_2dim(array, function):
    """Call ``function`` on each top-level element of ``array`` and return ``array`` itself.

    ``function``'s return values are ignored; it is used for its side effects.
    """
    for row in array:
        function(row)
    return array


def parallelize(array_or_dict, func):
    """Call ``func`` concurrently, one thread per element, and wait for all threads.

    For a list, each thread runs ``func(item)``. For a dict, each thread runs
    ``func(key, value)`` for every entry, including falsy keys such as ``0`` or
    ``''``. Exceptions raised by ``func`` are reported by the ``threading``
    module inside the worker and are not re-raised here.

    :raises Exception: if ``array_or_dict`` is neither a list nor a dict
    """
    if isinstance(array_or_dict, list):
        call_args = [(item,) for item in array_or_dict]
    elif isinstance(array_or_dict, dict):
        # Decide the call shape from the container type, not from the key's
        # truthiness (the old ``if self.key:`` called func(item) for key 0).
        call_args = list(array_or_dict.items())
    else:
        raise Exception("Expected either array or dict")
    threads = [threading.Thread(target=func, args=args) for args in call_args]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()


def get_start_and_end(diff_secs, format="%m/%d/%Y %H:%M", escape=True):
    """Return ``[start, end]`` for a window ending now (UTC) and starting ``diff_secs`` earlier.

    When ``format`` is a string, both bounds are rendered with ``strftime(format)``
    and, if ``escape`` is true, URL-encoded with ``urllib.quote_plus``. Otherwise
    the naive UTC ``datetime`` objects are returned as-is.
    """
    end = datetime.utcnow()
    start = end + timedelta(seconds=-diff_secs)
    if not isinstance(format, basestring):
        return [start, end]
    bounds = [start.strftime(format), end.strftime(format)]
    if escape:
        bounds = [urllib.quote_plus(bound) for bound in bounds]
    return bounds


def run_presto_query(presto_sql, hostname, port=8081):
    """Execute ``presto_sql`` on the Presto server at ``hostname:port`` and return all rows.

    :raises Exception: if ``presto_sql`` is ``None`` or the empty string
    """
    if presto_sql == "" or presto_sql is None:
        raise Exception("Invalid Presto query: '%s'" % presto_sql)
    cursor = pyhive.presto.connect(hostname, port).cursor()
    cursor.execute(presto_sql)
    return cursor.fetchall()